| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -49,12 +49,50 @@ void VTanh(const T* x, T* y, int n) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void Softmax(const T* x, T* y, int n, int bs) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto compute_vexp =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  typename XRNTuples<T>::func_type compute_hmax{nullptr};
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  typename XRNTuples<T>::func_type compute_hsum{nullptr};
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  typename AXYNTuples<T>::func_type compute_vscal{nullptr};
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  typename XYNTuples<T>::func_type compute_vexp{nullptr};
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                                               compute_vscal);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        n, compute_vaddbias);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_vaddbias =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  for (int i = 0; i < bs; ++i) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    T scalar;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    compute_hmax(x, &scalar, n);
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |