| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -24,8 +24,9 @@ template <typename T>
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					class GatherOpCUDAKernel : public framework::OpKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					 public:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void Compute(const framework::ExecutionContext &ctx) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                   "This kernel only runs on GPU device.");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      platform::errors::PreconditionNotMet(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          "This kernel only runs on GPU device."));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto *x = ctx.Input<Tensor>("X");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto *index = ctx.Input<Tensor>("Index");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto *output = ctx.Output<Tensor>("Out");
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				 | 
				
					@ -35,12 +36,15 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const auto &index_type = index->type();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                            index_type == framework::proto::VarType::INT64;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        index_type_match,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE_EQ(index_type_match, true,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      platform::errors::InvalidArgument(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          "Index holds the wrong type, it holds [%s],"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          "but desires to be [%s] or [%s].",
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          paddle::framework::DataTypeToString(index_type),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          paddle::framework::DataTypeToString(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                              framework::proto::VarType::INT32),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          paddle::framework::DataTypeToString(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                              framework::proto::VarType::INT64)));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    if (index_type == framework::proto::VarType::INT32) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      GPUGather<T, int>(ctx.device_context(), *x, *index, output);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    } else if (index_type == framework::proto::VarType::INT64) {
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				 | 
				
					@ -53,8 +57,9 @@ template <typename T>
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					 public:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  void Compute(const framework::ExecutionContext &ctx) const override {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                   "This kernel only runs on GPU device.");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      platform::errors::PreconditionNotMet(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          "This kernel only runs on GPU device."));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto *index = ctx.Input<Tensor>("Index");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				 | 
				
					@ -69,12 +74,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const auto &index_type = index->type();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                            index_type == framework::proto::VarType::INT64;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        index_type_match,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE_EQ(index_type_match, true,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      platform::errors::InvalidArgument(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          "Index holds the wrong type, it holds [%s],"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          "but desires to be [%s] or [%s].",
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          paddle::framework::DataTypeToString(index_type),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          paddle::framework::DataTypeToString(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                              framework::proto::VarType::INT32),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                          paddle::framework::DataTypeToString(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                              framework::proto::VarType::INT64)));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    if (index_type == framework::proto::VarType::INT32) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                               ctx.Attr<bool>("overwrite"));
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |