@@ -19,20 +19,7 @@
 #include <float.h>
 #include "nnacl/fp32/activation_fp32.h"
 #include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl/fp32/mul_fp32.h"
-
-void InitGate(float *gate_buffer, const float *bias, const LstmParameter *lstm_parm) {
-  int gate_offest = 0;
-  for (int l = 0; l < 4; l++) {
-    int batch_offest = gate_offest;
-    int bias_offest = l * lstm_parm->hidden_size_;
-    for (int b = 0; b < lstm_parm->batch_; b++) {
-      memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float));
-      batch_offest += lstm_parm->hidden_size_;
-    }
-    gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
-  }
-}
+#include "nnacl/fp32/matmul_fp32.h"
 
 // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
 void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
@@ -134,106 +121,131 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd
   }
 }
 
-void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
-                  const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
-                  const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
-                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
-                  const LstmParameter *lstm_parm) {
-  InitGate(gate_buffer, bias, lstm_parm);
+void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+  if (is_vec) {
+    memcpy(c, bias, col * sizeof(float));
+    MatMulAcc(c, a, b, row, col, deep);
+  } else {
+    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
+  }
+}
+
+void PackLstmInput(float *dst, const float *src, int row, int deep) {
+#ifdef ENABLE_AVX
+  RowMajor2Col6Major(src, dst, row, deep);
+#elif defined(ENABLE_SSE)
+  RowMajor2Col4Major(src, dst, row, deep);
+#else
+  RowMajor2Col12Major(src, dst, row, deep);
+#endif
+}
+
+void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
+                int col, int col_align, bool is_vec) {
+  const float *input_weight = weight;
+  const float *forget_weight = weight + deep * col * 2;
+  const float *cell_weight = weight + deep * col * 3;
+  const float *output_weight = weight + deep * col;
+
+  const float *input_bias = bias;
+  const float *forget_bias = bias + col_align * 2;
+  const float *cell_bias = bias + col_align * 3;
+  const float *output_bias = bias + col_align;
+
   float *input_gate = gate_buffer;
-  float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
-  float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
-  float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;
+  float *forget_gate = gate_buffer + row * col * 2;
+  float *cell_gate = gate_buffer + row * col * 3;
+  float *output_gate = gate_buffer + row * col;
+
+  LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec);
+  LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec);
+  LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec);
+  LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec);
+}
+
+void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
+                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+                  float *matmul_buffer[2], const LstmParameter *lstm_param) {
+  bool is_vec = lstm_param->batch_ == 1;
   // input * weight
-  MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_);
-  MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->input_size_);
-  MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_);
-  MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->input_size_);
+  if (is_vec) {
+    UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  } else {
+    // pack input for matmul
+    PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_);
+    UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  }
 
   // state * weight
-  MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
-  MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
-  MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
-  MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
+  float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
+  const float *state_bias = bias + lstm_param->col_align_ * 4;
+  if (is_vec) {
+    UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  } else {
+    // pack state for matmul
+    PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_);
+    UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  }
+  ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
+
+  float *input_gate = gate_buffer;
+  float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2;
+  float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3;
+  float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_;
+
   // update input_gate
-  Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate);
+  Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate);
 
   // update forget_gate
-  Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate);
+  Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate);
 
   // update cell_gate
-  Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate);
+  Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate);
 
   // update cell state
-  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
-              lstm_parm->smooth_);
+  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_,
+              lstm_param->hidden_size_, lstm_param->smooth_);
 
   // update output_gate
-  Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate);
+  Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate);
 
   // update output
-  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
-               lstm_parm->smooth_);
-  memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
-
-  if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
-    memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
-    memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
-           lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_,
+               lstm_param->smooth_);
+  memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
+
+  if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) {
+    memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
+    memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_,
+           lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
   }
 }
 
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
-          const LstmParameter *lstm_parm) {
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2],
+          const LstmParameter *lstm_param) {
   // forward
-  const float *input_input_weight = weight_i;
-  const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
-  const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
-  const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;
-
-  const float *state_input_weight = weight_h;
-  const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
-  const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
-  const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;
-
-  for (int t = 0; t < lstm_parm->seq_len_; t++) {
-    const float *input_ptr = input + t * lstm_parm->input_step_;
-    float *output_ptr = output + t * lstm_parm->output_step_;
-    LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight,
-                 state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state,
-                 cell_state, gate_buffer, state_buffer, lstm_parm);
+  for (int t = 0; t < lstm_param->seq_len_; t++) {
+    const float *input_ptr = input + t * lstm_param->input_step_;
+    float *output_ptr = output + t * lstm_param->output_step_;
+    LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer,
+                 matmul_buffer, lstm_param);
   }
 
   // backward
-  if (lstm_parm->bidirectional_) {
-    input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
-    input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
-    input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
-    input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;
-
-    state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
-    state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
-    state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
-    state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;
-
-    float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
-    const float *backward_bias = bias + 4 * lstm_parm->hidden_size_;
-    float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
-    float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
-    for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
-      const float *input_ptr = input + t * lstm_parm->input_step_;
-      float *output_ptr = backward_output + t * lstm_parm->output_step_;
-      LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
-                   input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
-                   backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm);
+  if (lstm_param->bidirectional_) {
+    const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
+    const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
+    const float *backward_bias = bias + 8 * lstm_param->hidden_size_;
+    float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
+    float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
+    float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
+    for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
+      const float *input_ptr = input + t * lstm_param->input_step_;
+      float *output_ptr = backward_output + t * lstm_param->output_step_;
+      LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state,
+                   backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param);
     }
   }
 }
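
Note on the packed layout (editor's addition, not part of the patch): the fused weight, bias, and gate buffers that the new UpdateGate indexes store the four LSTM gates back to back in the order input, output, forget, cell, with strides of deep * col (weights), col_align (bias), and row * col (gate buffer), as the offsets in the diff show. A minimal standalone sketch of those offsets follows; the dimensions are illustrative assumptions, not values taken from the patch.

/* Standalone illustration of the gate offsets used by UpdateGate above. */
#include <stdio.h>

int main(void) {
  const int deep = 16;     /* assumed input_size_ (inner dimension of the weight) */
  const int col = 8;       /* assumed hidden_size_ (output dimension of the weight) */
  const int col_align = 8; /* assumed aligned hidden size used for the bias buffer */
  const int row = 1;       /* assumed batch_ */

  /* Gate k starts at k * stride; the order matches the patch: I, O, F, C. */
  const char *gate_names[4] = {"input", "output", "forget", "cell"};
  for (int k = 0; k < 4; ++k) {
    printf("%-6s gate: weight offset %d, bias offset %d, gate_buffer offset %d\n",
           gate_names[k], k * deep * col, k * col_align, k * row * col);
  }
  return 0;
}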
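A second note on the packing path (editor's addition, not part of the patch): when batch_ == 1 (is_vec), LstmMatmul copies the bias into the destination and lets MatMulAcc accumulate directly on the unpacked row vector, so PackLstmInput is skipped; otherwise the activations are repacked into the column-tiled layout that the platform's MatMulOpt kernel expects for its left-hand operand. The helper below only mirrors that compile-time tile choice so it is visible in one place; the macro names come from the patch, but the function itself is hypothetical.

/* Hypothetical helper: reports the pack tile width chosen by PackLstmInput. */
static int LstmPackTileWidth(void) {
#ifdef ENABLE_AVX
  return 6;  /* RowMajor2Col6Major */
#elif defined(ENABLE_SSE)
  return 4;  /* RowMajor2Col4Major */
#else
  return 12; /* RowMajor2Col12Major */
#endif
}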