@@ -54,13 +54,12 @@ class KLDivLossKernel : public framework::OpKernel<T> {
     auto input_t = EigenVector<T>::Flatten(*input);
     auto target_t = EigenVector<T>::Flatten(*target);
     auto loss_t = EigenVector<T>::Flatten(*loss);
-    // auto target_mask = (target_t > target_t.constant(0)).template cast<T>();
-    // auto output = (target_t * (target_t.log() - input_t)) * target_mask;
     auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
     if ("none" == reduction) {
       loss_t.device(place) = output;
     } else if ("batchmean" == reduction) {
-      loss_t.device(place) = output.sum() / static_cast<T>(n);
+      auto output_sum = output.sum().eval();
+      loss_t.device(place) = output_sum / output_sum.constant(n);
     } else if ("mean" == reduction) {
       loss_t.device(place) = output.mean();
     } else if ("sum" == reduction) {
@@ -74,19 +73,17 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-    auto* input = ctx.Input<Tensor>("X");
     auto* target = ctx.Input<Tensor>("Target");
     auto reduction = ctx.Attr<std::string>("reduction");
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
 
-    const int n = input->dims()[0];
-    const int numel = input->numel();
+    const int n = input_grad->dims()[0];
+    const int numel = input_grad->numel();
     const int expand = numel / loss_grad->numel();
 
     input_grad->mutable_data<T>(ctx.GetPlace());
 
-    auto input_t = EigenVector<T>::Flatten(*input);
     auto target_t = EigenVector<T>::Flatten(*target);
 
     auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
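
Why "X" can be dropped from the gradient kernel in the hunk above: for the
masked KL term target * (log(target) - input), the derivative with respect
to input is just -target (and 0 where target <= 0), so the backward pass
never reads the forward input, and sizes can be taken from the output
gradient tensor instead. A hedged element-wise sketch (functor name and
signature are illustrative, not from the source):

// Sketch of the per-element backward rule; `grad` is the upstream
// loss gradient for this element.
template <typename T>
struct KLDivLossBackwardSketch {
  T operator()(const T& target, const T& grad) const {
    return target > static_cast<T>(0) ? -target * grad : static_cast<T>(0);
  }
};
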
@@ -96,14 +93,6 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
     auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
     input_grad_t.device(place) =
         target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;
-    // if (reduction == "none") {
-    //   input_grad_t.device(place) =
-    //       target_t * loss_grad_t * target_t.constant(-1.0);
-    // } else {
-    //   auto loss_grad_expand = loss_grad_t.broadcast(Array1(numel));
-    //   input_grad_t.device(place) =
-    //       target_t * loss_grad_expand * target_t.constant(-1.0);
-    // }
 
     if ("mean" == reduction) {
       input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
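
The surviving path above replaces the commented-out branching: the upstream
gradient is always broadcast by `expand` (1 when reduction is "none", numel
when the loss is a scalar), and "mean" additionally rescales by 1/numel.
A plain-C++ sketch of that behavior, with illustrative names:

#include <vector>

// Sketch only: tile loss_grad across the input (Eigen-style broadcast),
// apply d(loss)/d(input) = -target under the target > 0 mask, then
// rescale for "mean" reduction.
template <typename T>
std::vector<T> KLDivGradSketch(const std::vector<T>& target,
                               const std::vector<T>& loss_grad,
                               bool mean_reduction) {
  const int numel = static_cast<int>(target.size());
  std::vector<T> input_grad(numel);
  for (int i = 0; i < numel; ++i) {
    const T g = loss_grad[i % loss_grad.size()];  // broadcast
    input_grad[i] = target[i] > T(0) ? -target[i] * g : T(0);
    if (mean_reduction) input_grad[i] /= static_cast<T>(numel);
  }
  return input_grad;
}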