@@ -20,6 +20,7 @@
 #include "nnacl/fp16/activation_fp16.h"
 #include "nnacl/fp16/arithmetic_fp16.h"
 #include "nnacl/fp16/matmul_fp16.h"
+#include "nnacl/fp16/cast_fp16.h"

 void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) {
   for (int i = 0; i < batch; i++) {
@@ -37,6 +38,43 @@ void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int dee
   }
 }

+void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align,
+                            bool is_bidirectional) {
+  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
+  for (int i = 0; i < unidirectional_batch; i++) {
+    const float *src_batch = src + i * col;
+    float16_t *dst_batch = dst + i * col_align;
+    Float32ToFloat16(src_batch, dst_batch, col);
+  }
+  if (is_bidirectional) {
+    const float *backward_src = src + batch * col;
+    float16_t *backward_dst = dst + unidirectional_batch * col_align;
+    for (int i = 0; i < unidirectional_batch; i++) {
+      const float *backward_src_batch = backward_src + i * col;
+      float16_t *backward_dst_batch = backward_dst + i * col_align;
+      Float32ToFloat16(backward_src_batch, backward_dst_batch, col);
+    }
+  }
+}
+
+void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional) {
+  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
+  for (int i = 0; i < unidirectional_batch; i++) {
+    const float16_t *src_batch = src + i * col;
+    float16_t *dst_batch = dst + i * col_align;
+    memcpy(dst_batch, src_batch, col * sizeof(float16_t));
+  }
+  if (is_bidirectional) {
+    const float16_t *backward_src = src + batch * col;
+    float16_t *backward_dst = dst + unidirectional_batch * col_align;
+    for (int i = 0; i < unidirectional_batch; i++) {
+      const float16_t *backward_src_batch = backward_src + i * col;
+      float16_t *backward_dst_batch = backward_dst + i * col_align;
+      memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t));
+    }
+  }
+}
+
 // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
 void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
                    int inner_size) {
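
Note on the two bias-pack helpers added above: each of the `batch` gate-bias rows of length `col` is copied out at a stride of `col_align`, so every destination row carries tail padding, and for the bidirectional case the backward half is appended after the forward rows. A minimal layout sketch with hypothetical values (col = 3, col_align = 4; the real alignment is chosen by the caller):

/* src (flat)  : a0 a1 a2 | b0 b1 b2
 * dst (packed): a0 a1 a2 __ | b0 b1 b2 __   (__ = padding, presumably zeroed by the caller) */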
@@ -149,40 +187,32 @@ void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const fl
   }
 }

-void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight,
-                      const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state,
-                      float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer[2],
-                      float16_t *matmul_buffer[2], const LstmParameter *lstm_param) {
+void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forget_gate, float16_t *cell_gate,
+                      float16_t *output_gate, const float16_t *state_weight, const float16_t *state_bias,
+                      float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[6],
+                      const LstmParameter *lstm_param) {
+  float16_t *packed_state = buffer[2];
+  float16_t *state_gate = buffer[3];
+  float16_t *cell_buffer = buffer[4];
+  float16_t *hidden_buffer = buffer[5];
   bool is_vec = lstm_param->batch_ == 1;
-  // input * weight
-  if (is_vec) {
-    UpdateLstmGateFp16(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
-                       lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
-  } else {
-    // pack input for matmul
-    RowMajor2Col16MajorFp16(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_, false);
-    UpdateLstmGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
-                       lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
-  }
-
-  // state * weight
-  float16_t *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
-  const float16_t *state_bias = bias + lstm_param->col_align_ * 4;
   if (is_vec) {
     UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
-                       lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+                       lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
   } else {
     // pack state for matmul
-    RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, false);
-    UpdateLstmGateFp16(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_,
-                       lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+    RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_, false);
+    UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+                       lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
   }
-  ElementAddFp16(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
-
-  float16_t *input_gate = gate_buffer;
-  float16_t *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2;
-  float16_t *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3;
-  float16_t *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_;
+  ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
+  ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate,
+                 lstm_param->batch_ * lstm_param->hidden_size_);
+  ElementAddFp16(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate,
+                 lstm_param->batch_ * lstm_param->hidden_size_);
+  ElementAddFp16(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate,
+                 lstm_param->batch_ * lstm_param->hidden_size_);

   // update input_gate
   SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
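
The ElementAddFp16 offsets above pin down the gate ordering: state_gate holds the four gates as contiguous slabs in the order input, output, forget, cell. A sketch of the implied indexing, writing B for lstm_param->batch_ and H for lstm_param->hidden_size_:

/* state_gate + 0 * B * H  ->  input gate
 * state_gate + 1 * B * H  ->  output gate
 * state_gate + 2 * B * H  ->  forget gate
 * state_gate + 3 * B * H  ->  cell gate */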
@@ -192,50 +222,76 @@ void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t
   // update cell_gate
   TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_);

   // update cell state
-  UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer[0], lstm_param->batch_,
+  UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_,
                   lstm_param->hidden_size_, lstm_param->zoneout_cell_);

   // update output_gate
   SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_);

   // update output
-  UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer[1], lstm_param->batch_, lstm_param->hidden_size_,
+  UpdataOutputFp16(cell_state, output_gate, hidden_state, hidden_buffer, lstm_param->batch_, lstm_param->hidden_size_,
                    lstm_param->zoneout_hidden_);
   memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));

   if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) {
-    memcpy(cell_state, state_buffer[0], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
+    memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
   }

   if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) {
-    memcpy(hidden_state, state_buffer[1], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
+    memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
   }
 }

-void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
-              const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
-              float16_t *state_buffer[2], float16_t *matmul_buffer[2], const LstmParameter *lstm_param) {
-  // forward
+void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i,
+                            const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias,
+                            float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[6],
+                            const LstmParameter *lstm_param, bool is_backward) {
+  float16_t *gate = buffer[1];
+  for (int i = 0; i < 4; i++) {
+    const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i;
+    const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i;
+    float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i;
+    MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_,
+               lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_,
+               OutType_Nhwc);
+  }
+
+  float16_t *input_gate = gate;
+  float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2;
+  float16_t *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3;
+  float16_t *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_;
   for (int t = 0; t < lstm_param->seq_len_; t++) {
-    const float16_t *input_ptr = input + t * lstm_param->input_step_;
-    float16_t *output_ptr = output + t * lstm_param->output_step_;
-    LstmStepUnitFp16(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer,
-                     state_buffer, matmul_buffer, lstm_param);
+    int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t;
+    float16_t *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
+    float16_t *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
+    float16_t *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
+    float16_t *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
+    float16_t *output_ptr = output + real_t * lstm_param->output_step_;
+    LstmStepUnitFp16(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
+                     hidden_state, cell_state, buffer, lstm_param);
   }
+}
+
+void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
+              const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *cell_state,
+              float16_t *buffer[6], const LstmParameter *lstm_param) {
+  // forward
+  float16_t *packed_input = buffer[0];
+  RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_,
+                          false);
+  LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state,
+                         buffer, lstm_param, false);

   // backward
   if (lstm_param->bidirectional_) {
-    const float16_t *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
-    const float16_t *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
-    const float16_t *backward_bias = bias + 8 * lstm_param->col_align_;
+    const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_;
+    const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_;
+    const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_;
+    const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_;
     float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
     float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
     float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
-    for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
-      const float16_t *input_ptr = input + t * lstm_param->input_step_;
-      float16_t *output_ptr = backward_output + t * lstm_param->output_step_;
-      LstmStepUnitFp16(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias,
-                       backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, matmul_buffer,
-                       lstm_param);
-    }
+    LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias,
+                           backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true);
   }
 }
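
Taken together, the refactor collapses the old gate_buffer / state_buffer / matmul_buffer scratch areas into a single buffer[6] shared by LstmFp16, LstmUnidirectionalFp16, and LstmStepUnitFp16. The slot roles below are read directly off the indexing in this diff; the sizes are assumptions inferred from that indexing (S = seq_len_, B = batch_, H = hidden_size_), since the kernel that allocates the buffers is outside this file:

/* buffer[0]  packed_input  : input packed by RowMajor2Col16MajorFp16
 * buffer[1]  gate          : 4 * S * B * H input-side gate pre-activations
 * buffer[2]  packed_state  : hidden_state packed for the state matmul
 * buffer[3]  state_gate    : 4 * B * H state-side gate results
 * buffer[4]  cell_buffer   : zoneout staging copied back into cell_state
 * buffer[5]  hidden_buffer : zoneout staging copied back into hidden_state */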