|
|
@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
|
|
GPUGather<T>(ctx.GetPlace(), x, index, output);
|
|
|
|
GPUGather<T>(ctx.device_context(), x, index, output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -42,7 +42,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
LOG(INFO) << "Gather grad here";
|
|
|
|
|
|
|
|
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
@ -53,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
|
|
|
|
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
|
|
|
|
dxt.device(place) = dxt.constant(static_cast<T>(0));
|
|
|
|
dxt.device(place) = dxt.constant(static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
|
|
GPUScatterAssign<T>(ctx.GetPlace(), dO, Index, dX);
|
|
|
|
GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|