|
|
|
@ -179,23 +179,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated */
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_);
|
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct);
|
|
|
|
|
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
|
|
|
|
|
/* C_t = igated * cgated*/
|
|
|
|
|
act_gate_d_->Compute(gates + d_, gates + d_);
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct, d_);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -289,36 +289,36 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
|
|
|
|
|
T* checked) const override {
|
|
|
|
|
/* get fgated and igated*/
|
|
|
|
|
vmul_d_->Compute(wp_data, ct_1, checked);
|
|
|
|
|
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_);
|
|
|
|
|
vmul_d_->Compute(wp_data, ct_1, checked, d_);
|
|
|
|
|
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
|
|
|
|
|
vadd_d2_->Compute(checked, gates + d_, gates + d_);
|
|
|
|
|
act_gate_d2_->Compute(gates + d_, gates + d_);
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_);
|
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct);
|
|
|
|
|
/* get ogated*/
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_);
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
|
|
|
|
|
/* C_t = igated * cgated*/
|
|
|
|
|
act_gate_d_->Compute(gates + d_, gates + d_);
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct, d_);
|
|
|
|
|
/* get outgated, put W_oc * C_t on igated */
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_);
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -352,8 +352,8 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
act_cell, d)); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
|
|
|
|
|
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
|
|
|
|
|
REGISTER_JITKERNEL_ARGS_DEPRECATED(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
|
|
|
|
|
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
|
|
|
|
|
|
|
|
|
|
#undef INTRI8_FLOAT
|
|
|
|
|
#undef JITKERNEL_DECLARE_LSTM
|
|
|
|
@ -378,13 +378,13 @@ class GRUKernelImpl : public GRUKernel<T> {
|
|
|
|
|
void ComputeH1(T* gates, T* ht) const override {
|
|
|
|
|
act_gate_d_->Compute(gates, gates);
|
|
|
|
|
act_state_d_->Compute(gates + d2_, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d2_, ht);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d2_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
|
|
|
|
|
// W: {W_update, W_reset; W_state}
|
|
|
|
|
act_gate_d2_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(ht_1, gates + d_, ht);
|
|
|
|
|
vmul_d_->Compute(ht_1, gates + d_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
|
|
|
|
@ -472,8 +472,8 @@ INTRI8_FLOAT(jit::avx512f);
|
|
|
|
|
p = std::dynamic_pointer_cast<ker<dtype>>( \
|
|
|
|
|
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
|
|
|
|
|
|
|
|
|
|
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
|
|
|
|
|
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
|
|
|
|
|
REGISTER_JITKERNEL_ARGS_DEPRECATED(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
|
|
|
|
|
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
|
|
|
|
|
|
|
|
|
|
#undef INTRI8_FLOAT
|
|
|
|
|
#undef JITKERNEL_NEW_GRU_IMPL
|
|
|
|
|