|
|
|
@ -39,7 +39,7 @@ namespace operators {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
class SamplingIdGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const Tensor* input = context.Input<Tensor>("X");
|
|
|
|
@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
|
|
|
|
|
paddle::operators::SamplingIdKernel<double>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(sampling_id,
|
|
|
|
|
paddle::operators::SamplingIdGPUKernel<float>,
|
|
|
|
|
paddle::operators::SamplingIdGPUKernel<double>);
|
|
|
|
|