You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							195 lines
						
					
					
						
							7.4 KiB
						
					
					
				
			
		
		
	
	
							195 lines
						
					
					
						
							7.4 KiB
						
					
					
				/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License. */
 | 
						|
 | 
						|
#include <cublas.h>
 | 
						|
#include <string>
 | 
						|
#include "paddle/fluid/framework/eigen.h"
 | 
						|
#include "paddle/fluid/operators/batch_fc_op.h"
 | 
						|
#include "paddle/fluid/operators/math/blas.h"
 | 
						|
#include "paddle/fluid/platform/cuda_primitives.h"
 | 
						|
#include "paddle/fluid/platform/gpu_info.h"
 | 
						|
 | 
						|
namespace paddle {
 | 
						|
namespace operators {
 | 
						|
using framework::Tensor;
 | 
						|
 | 
						|
// Threads per block used by every kernel launch in this file.
constexpr int CUDA_NUM_THREADS = 1024;

// Number of CUDA_NUM_THREADS-wide blocks needed to cover N elements,
// i.e. ceil(N / CUDA_NUM_THREADS). Yields 0 when N == 0.
static inline int GET_BLOCKS(const int N) {
  const int padded = N + CUDA_NUM_THREADS - 1;
  return padded / CUDA_NUM_THREADS;
}
 | 
						|
 | 
						|
// Adds a per-slot bias row to every instance row, in place.
//   data: slot_pairs_num * ins_num * out_dim elements, indexed as
//         [slot][ins][col] (row-major), updated in place.
//   bias: slot_pairs_num * out_dim elements, indexed as [slot][col].
// Iterates over all elements with CUDA_KERNEL_LOOP.
template <typename T>
__global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
                                int out_dim, const T* bias) {
  const int per_slot = ins_num * out_dim;
  CUDA_KERNEL_LOOP(idx, slot_pairs_num * ins_num * out_dim) {
    const int slot = idx / per_slot;
    // per_slot is a multiple of out_dim, so the column is just idx % out_dim.
    const int col = idx % out_dim;
    data[idx] += bias[slot * out_dim + col];
  }
}
 | 
						|
 | 
						|
// Host-side launcher for add_bias_kernel on `stream`.
// Adds bias[slot][col] to data[slot][ins][col] in place.
// Empty tensors are skipped: GET_BLOCKS(0) == 0, and a launch with a
// zero-sized grid fails with an invalid-configuration error.
template <typename T>
void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
              int out_dim, const T* bias) {
  const int total = slot_pairs_num * ins_num * out_dim;
  if (total <= 0) {
    return;  // nothing to do; avoid an invalid 0-block launch
  }
  add_bias_kernel<<<GET_BLOCKS(total), CUDA_NUM_THREADS, 0, stream>>>(
      data, slot_pairs_num, ins_num, out_dim, bias);
}
 | 
						|
 | 
						|
// Bias gradient: db[slot][col] += sum over ins of dout[slot][ins][col].
//   dout_data: slot_pairs_num * ins_num * out_dim elements, assumed to share
//              Out's row-major [slot][ins][col] layout.
//   db_data:   slot_pairs_num * out_dim elements, [slot][col]; accumulated
//              into with +=, so callers are expected to zero it first.
// One loop iteration per (slot, col) pair; adjacent iterations read
// adjacent columns.
//
// Fix: the previous index ((row + 1) * i + 1) * col did not address
// dout[row][i][col] (e.g. row 0, ins 1 read element 2*col instead of
// out_dim + col), producing wrong bias gradients.
template <typename T>
__global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
                                     int ins_num, int out_dim, T* db_data) {
  CUDA_KERNEL_LOOP(idx, slot_pairs_num * out_dim) {
    const int row = idx / out_dim;  // slot pair
    const int col = idx % out_dim;  // output feature
    // First element of column `col` within slot `row`; widen before
    // multiplying to avoid int overflow on large tensors.
    const T* src =
        dout_data + static_cast<int64_t>(row) * ins_num * out_dim + col;
    T temp = static_cast<T>(0);
    for (int i = 0; i < ins_num; ++i) {
      temp += src[static_cast<int64_t>(i) * out_dim];
    }
    db_data[idx] += temp;
  }
}
 | 
						|
 | 
						|
// Host-side launcher for add_bias_grad_kernel on `stream`.
// Accumulates the per-slot column sums of dout into db_data.
// Skips the launch when there are no (slot, col) pairs, since a
// zero-sized grid is an invalid launch configuration.
template <typename T>
void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
                   int ins_num, int out_dim, T* db_data) {
  const int total = slot_pairs_num * out_dim;
  if (total <= 0) {
    return;  // empty bias gradient; avoid an invalid 0-block launch
  }
  add_bias_grad_kernel<<<GET_BLOCKS(total), CUDA_NUM_THREADS, 0, stream>>>(
      dout_data, slot_pairs_num, ins_num, out_dim, db_data);
}
 | 
						|
 | 
						|
template <typename DeviceContext, typename T>
class BatchFCCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward of batch_fc: for every slot pair b,
  //   Out[b] = Input[b] * W[b] + Bias[b]   (bias added to each row)
  // Shapes as read/set below:
  //   Input: [slot_pairs_num, ins_num, in_dim]
  //   W:     [slot_pairs_num, in_dim, out_dim]  (only dims[2] is read)
  //   Bias:  slot_pairs_num * out_dim elements
  //   Out:   resized to [slot_pairs_num, ins_num, out_dim]
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<framework::LoDTensor>("Input");
    const auto* weight = ctx.Input<Tensor>("W");
    const auto* bias = ctx.Input<Tensor>("Bias");
    auto* out = ctx.Output<framework::LoDTensor>("Out");

    const auto x_dims = x->dims();
    const auto num_slots = x_dims[0];
    const auto num_ins = x_dims[1];
    const auto in_dim = x_dims[2];
    const auto out_dim = weight->dims()[2];

    const T* x_ptr = x->data<T>();
    const T* w_ptr = weight->data<T>();
    const T* bias_ptr = bias->data<T>();

    out->Resize({num_slots, num_ins, out_dim});
    T* out_ptr = out->mutable_data<T>(ctx.GetPlace());

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    // Start from an all-zero output.
    auto out_vec = framework::EigenVector<T>::Flatten(*out);
    out_vec.device(*dev_ctx.eigen_device()) =
        out_vec.constant(static_cast<T>(0));

    // One GEMM per slot pair: [num_ins x in_dim] * [in_dim x out_dim].
    const T one = 1;
    const T zero = 0;
    const int64_t x_stride = num_ins * in_dim;
    const int64_t w_stride = in_dim * out_dim;
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    blas.BatchedGEMM(CblasNoTrans, CblasNoTrans, num_ins, out_dim, in_dim, one,
                     x_ptr, w_ptr, zero, out_ptr, num_slots, x_stride,
                     w_stride);

    add_bias<T>(dev_ctx.stream(), out_ptr, num_slots, num_ins, out_dim,
                bias_ptr);
  }
};
 | 
						|
 | 
						|
template <typename DeviceContext, typename T>
class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward of batch_fc. For every slot pair b:
  //   dInput[b] = dOut[b] * W[b]^T      ([ins_num x out_dim] * [out_dim x in_dim])
  //   dW[b]     = Input[b]^T * dOut[b]  ([in_dim x ins_num] * [ins_num x out_dim])
  //   dBias[b]  = column sums of dOut[b]
  // Each gradient output is computed only if ctx.Output returned a non-null
  // tensor. NOTE(review): this assumes the framework passes nullptr for
  // gradients that are not requested — confirm against the grad op maker.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("Input");
    auto* w = ctx.Input<Tensor>("W");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto* dw = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto* db = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto input_dims = input->dims();
    auto w_dims = w->dims();
    auto slot_pairs_num = input_dims[0];
    auto ins_num = input_dims[1];
    auto in_dim = input_dims[2];
    auto out_dim = w_dims[2];

    // Fetch the device context once and reuse it for Eigen, BLAS and streams.
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto& place = *dev_ctx.eigen_device();
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    const T alpha = 1;
    const T beta = 0;

    const T* dout_data = dout->data<T>();

    if (db) {
      // add_bias_grad accumulates with +=, so zero the buffer first.
      T* db_data = db->mutable_data<T>(ctx.GetPlace());
      auto db_eigen = framework::EigenVector<T>::Flatten(*db);
      db_eigen.device(place) = db_eigen.constant(static_cast<T>(0));
      add_bias_grad<T>(dev_ctx.stream(), dout_data, slot_pairs_num, ins_num,
                       out_dim, db_data);
    }

    if (dx) {
      T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
      auto dx_eigen = framework::EigenVector<T>::Flatten(*dx);
      dx_eigen.device(place) = dx_eigen.constant(static_cast<T>(0));
      // dx = dout * W^T
      const T* w_data = w->data<T>();
      blas.BatchedGEMM(CblasNoTrans, CblasTrans, ins_num, in_dim, out_dim,
                       alpha, dout_data, w_data, beta, dx_data, slot_pairs_num,
                       ins_num * out_dim, out_dim * in_dim);
    }

    if (dw) {
      T* dw_data = dw->mutable_data<T>(ctx.GetPlace());
      auto dw_eigen = framework::EigenVector<T>::Flatten(*dw);
      dw_eigen.device(place) = dw_eigen.constant(static_cast<T>(0));
      // dW = Input^T * dout
      const T* x_data = input->data<T>();
      blas.BatchedGEMM(CblasTrans, CblasNoTrans, in_dim, out_dim, ins_num,
                       alpha, x_data, dout_data, beta, dw_data, slot_pairs_num,
                       in_dim * ins_num, ins_num * out_dim);
    }
  }
};
 | 
						|
 | 
						|
}  // namespace operators
 | 
						|
}  // namespace paddle
 | 
						|
 | 
						|
namespace ops = paddle::operators;
using CUDACtx = paddle::platform::CUDADeviceContext;

// Forward kernels for batch_fc (float and double).
REGISTER_OP_CUDA_KERNEL(
    batch_fc,
    ops::BatchFCCUDAKernel<CUDACtx, float>,
    ops::BatchFCCUDAKernel<CUDACtx, double>);

// Backward kernels for batch_fc_grad (float and double).
REGISTER_OP_CUDA_KERNEL(
    batch_fc_grad,
    ops::BatchFCGradOpCUDAKernel<CUDACtx, float>,
    ops::BatchFCGradOpCUDAKernel<CUDACtx, double>);
 |