| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                   T* checked) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // gates: W_ch, W_ih, W_fh, W_oh
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* C_t = C_t-1 * fgated + cand_gated * igated */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->ComputeDeprecated(gates, gates);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->Compute(gates, gates, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* H_t = act_cell(C_t) * ogated */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->Compute(ct, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* C_t = igated * cgated*/
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->ComputeDeprecated(gates, gates);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->Compute(gates + d_, gates + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->Compute(gates, gates, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates, gates + d_, ct, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* H_t = act_cell(C_t) * ogated */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->Compute(ct, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(wp_data, ct_1, checked, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* C_t = C_t-1 * fgated + cand_gated * igated*/
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->ComputeDeprecated(gates, gates);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->Compute(gates, gates, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* get ogated*/
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* H_t = act_cell(C_t) * ogated */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->Compute(ct, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* C_t = igated * cgated*/
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->ComputeDeprecated(gates, gates);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->Compute(gates + d_, gates + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cand_d_->Compute(gates, gates, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates, gates + d_, ct, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* get outgated, put W_oc * C_t on igated */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    /* H_t = act_cell(C_t) * ogated */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->ComputeDeprecated(ct, gates + d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_cell_d_->Compute(ct, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void ComputeH1(T* gates, T* ht) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->ComputeDeprecated(gates, gates);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d_->Compute(gates, gates, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_state_d_->Compute(gates + d2_, gates + d2_, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(gates, gates + d2_, ht, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // W: {W_update, W_reset; W_state}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d2_->ComputeDeprecated(gates, gates);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_gate_d2_->Compute(gates, gates, d2_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    vmul_d_->Compute(ht_1, gates + d_, ht, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    T* y = gates + d2_;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_state_d_->ComputeDeprecated(y, y);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    act_state_d_->Compute(y, y, d_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // out = zt*ht~ + (1-zt)*ht_1
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    for (int i = 0; i < d_; ++i) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |