|
|
|
@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
// Computes the cell state C_t and the hidden state H_t for one LSTM step
// from the pre-activation gate buffer and the previous cell state.
//
//   gates   : in/out working buffer laid out as {W_ch, W_ih, W_fh, W_oh};
//             it is overwritten and reused as scratch space.
//   ct_1    : previous cell state C_t-1 (read only).
//   ct      : receives C_t (d_ values).
//   ht      : receives H_t (d_ values).
//   wp_data : not referenced by this implementation.
//   checked : not referenced by this implementation.
//
// NOTE(review): the offsets assume d2_ == 2 * d_ and d3_ == 3 * d_ -- confirm
// against where these members are initialised.
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
                 T* checked) const override {
  // gates: W_ch, W_ih, W_fh, W_oh
  // Gate activation, in place, over the d3_ values starting at gates + d_.
  // It must run exactly once: the input and output buffers are the same.
  act_gate_d3_->Compute(gates + d_, gates + d_, d3_);

  /* C_t = C_t-1 * fgated + cand_gated * igated */
  act_cand_d_->Compute(gates, gates, d_);
  // cand_gated * igated, stored over the slice at gates + d_.
  vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
  // C_t-1 * fgated, stored over the slice at gates + d2_.
  vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
  vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);

  /* H_t = act_cell(C_t) * ogated */
  // The slice at gates + d2_ is no longer needed and holds act_cell(C_t).
  act_cell_d_->Compute(ct, gates + d2_, d_);
  vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
|
|
|
|
|
// Computes C_t and H_t for a step that has no previous cell state (the
// signature takes no ct_1), so the cell state is igated * cgated only.
//
//   gates   : in/out working buffer; overwritten and reused as scratch.
//   ct      : receives C_t (d_ values).
//   ht      : receives H_t (d_ values).
//   wp_data : not referenced by this implementation.
//
// NOTE(review): the offsets assume d2_ == 2 * d_ and d3_ == 3 * d_ -- confirm
// against where these members are initialised.
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
  /* C_t = igated * cgated*/
  // Both activations are in place, so each must run exactly once.
  act_gate_d_->Compute(gates + d_, gates + d_, d_);
  act_cand_d_->Compute(gates, gates, d_);
  vmul_d_->Compute(gates, gates + d_, ct, d_);

  /* H_t = act_cell(C_t) * ogated */
  act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
  // The slice at gates + d2_ is never activated here; it only serves as
  // scratch for act_cell(C_t).
  act_cell_d_->Compute(ct, gates + d2_, d_);
  vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
|
|
|
|
|
|
|
|
|
@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
vmul_d_->Compute(wp_data, ct_1, checked, d_);
|
|
|
|
|
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
|
|
|
|
|
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
|
|
|
|
|
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_);
|
|
|
|
|
act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
|
act_cand_d_->ComputeDeprecated(gates, gates);
|
|
|
|
|
act_cand_d_->Compute(gates, gates, d_);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
|
|
|
|
|
/* get ogated*/
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
|
|
|
|
|
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_, d_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Peephole variant of the step that has no previous cell state (the
// signature takes no ct_1). The output gate additionally receives a
// contribution from the freshly computed C_t, weighted by wp_data + d2_.
//
//   gates   : in/out working buffer; overwritten and reused as scratch.
//   ct      : receives C_t (d_ values).
//   ht      : receives H_t (d_ values).
//   wp_data : peephole weights; only the d_ values at wp_data + d2_ are read.
//
// NOTE(review): the offsets assume d2_ == 2 * d_ and d3_ == 3 * d_ -- confirm
// against where these members are initialised.
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
  /* C_t = igated * cgated*/
  // Both activations are in place, so each must run exactly once.
  act_gate_d_->Compute(gates + d_, gates + d_, d_);
  act_cand_d_->Compute(gates, gates, d_);
  vmul_d_->Compute(gates, gates + d_, ct, d_);

  /* get outgated, put W_oc * C_t on igated */
  // igated has already been consumed above, so its slice is free to hold
  // the peephole term before it is added to the output-gate input.
  vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
  vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);

  /* H_t = act_cell(C_t) * ogated */
  act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
  // The slice at gates + d2_ serves as scratch for act_cell(C_t).
  act_cell_d_->Compute(ct, gates + d2_, d_);
  vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
|
|
|
|
|
|
|
|
|
@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Computes the hidden state for a step that has no previous hidden state
// (the signature takes no ht_1): ht = act_gate(gates[0..d_)) *
// act_state(gates[d2_..d2_ + d_)).
//
//   gates : in/out working buffer; the two slices used are activated in place.
//   ht    : receives the result (d_ values).
//
// NOTE(review): the offset assumes d2_ == 2 * d_ -- confirm against where
// the member is initialised.
void ComputeH1(T* gates, T* ht) const override {
  // Both activations are in place, so each must run exactly once.
  act_gate_d_->Compute(gates, gates, d_);
  act_state_d_->Compute(gates + d2_, gates + d2_, d_);
  vmul_d_->Compute(gates, gates + d2_, ht, d_);
}
|
|
|
|
|
|
|
|
|
|
// First half of a GRU step: activates the first d2_ gate values in place,
// then combines ht_1 element-wise with the slice at gates + d_ and stores
// the result in ht.
//
//   gates : in/out working buffer; its first d2_ values are overwritten.
//   ht_1  : previous hidden state (read only).
//   ht    : receives ht_1 combined with the slice at gates + d_ (d_ values).
//
// NOTE(review): going by the layout comment below, the slice at gates + d_
// is presumably the reset gate -- confirm against the caller.
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
  // W: {W_update, W_reset; W_state}
  // In-place activation: it must run exactly once.
  act_gate_d2_->Compute(gates, gates, d2_);
  vmul_d_->Compute(ht_1, gates + d_, ht, d_);
}
|
|
|
|
|
|
|
|
|
|
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
|
|
|
|
|
T* y = gates + d2_;
|
|
|
|
|
act_state_d_->ComputeDeprecated(y, y);
|
|
|
|
|
act_state_d_->Compute(y, y, d_);
|
|
|
|
|
// out = zt*ht~ + (1-zt)*ht_1
|
|
|
|
|
for (int i = 0; i < d_; ++i) {
|
|
|
|
|
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
|
|
|
|
|