|
|
|
@ -20,13 +20,8 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// template <typename T>
|
|
|
|
|
// Kernel that prints the elements of a float array from device code.
// Appears to be a debugging aid -- it produces no output other than printf.
//
// params: pointer to the floats to print; it is dereferenced inside the
//         kernel, so it must be readable from device code.
// N:      bound handed to CUDA_1D_KERNEL_LOOP.
//
// NOTE(review): CUDA_1D_KERNEL_LOOP is defined outside this view; presumably
// it hands each launched thread a subset of the indices in [0, N) -- confirm
// against the macro definition.
__global__ void print_arr(const float *params, const int N) {
|
|
|
|
|
// For every index i produced by the loop macro, emit one line of the form
// "device: <i>, <params[i]>".
CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class GatherOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
class GatherOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
@ -42,7 +37,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class GatherGradOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|