|
|
@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
|
|
|
|
|
|
|
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
|
|
|
|
/* get fgated and igated*/
|
|
|
|
/* get fgated and igated*/
|
|
|
|
vmul_d_->Compute(wp_data, ct_1, checked, d_);
|
|
|
|
vmul_d_->Compute(wp_data, ct_1, checked, d_);
|
|
|
|
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
|
|
|
|
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
|
|
|
|
vadd_d2_->Compute(checked, gates + d_, gates + d_);
|
|
|
|
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
|
|
|
|
act_gate_d2_->Compute(gates + d_, gates + d_);
|
|
|
|
act_gate_d2_->Compute(gates + d_, gates + d_);
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
|
|
|
|
/* get ogated*/
|
|
|
|
/* get ogated*/
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct, d_);
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct, d_);
|
|
|
|
/* get outgated, put W_oc * C_t on igated */
|
|
|
|
/* get outgated, put W_oc * C_t on igated */
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|