merge new op grammar

tonyyang-svail-feed-op-desgin
zchen0211 8 years ago
parent 88a8eedda1
commit b851515b16

@@ -20,13 +20,8 @@
 namespace paddle {
 namespace operators {
-// template <typename T>
-__global__ void print_arr(const float *params, const int N) {
-  CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); }
-}
 template <typename T>
-class GatherOpCUDAKernel : public framework::OpKernel {
+class GatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
@@ -42,7 +37,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
 };
 template <typename T>
-class GatherGradOpCUDAKernel : public framework::OpKernel {
+class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
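
The gather hunks above show the grammar this commit migrates to: the kernel class itself carries the element type and derives from framework::OpKernel<T> instead of the untyped framework::OpKernel, and the leftover debug print_arr kernel is dropped. A minimal sketch of a kernel written against that grammar follows; the class name CopyOpCUDAKernel and the include paths are illustrative assumptions, while OpKernel<T>, ExecutionContext, and the PADDLE_ENFORCE guard are taken from the diff.

// Sketch only: a hypothetical CUDA kernel following the OpKernel<T> grammar
// used above. Include paths are assumed from the framework layout of the time.
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace operators {

template <typename T>
class CopyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    // Same guard as the gather/scatter kernels: refuse to run off-GPU.
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    // The element type is now supplied by the class template parameter T,
    // so Compute itself no longer needs its own template parameter.
  }
};

}  // namespace operators
}  // namespace paddle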

@@ -20,7 +20,7 @@ namespace paddle {
 namespace operators {
 template <typename T>
-class ScatterOpCUDAKernel : public framework::OpKernel {
+class ScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
@@ -37,7 +37,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {
 };
 template <typename T>
-class ScatterGradOpCUDAKernel : public framework::OpKernel {
+class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
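
With the kernel classes templated on T, the concrete element type is supplied when the GPU kernel is registered. The registration lines below are an assumption about the operator-registry macro of that era and are not part of this diff.

// Hypothetical registration (not in this commit): the typed kernel classes
// are instantiated with a concrete element type at registration time.
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>);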
