|
|
|
@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int batch_size = static_cast<int>(input->dims()[0]);
|
|
|
|
|
const int width = static_cast<int>(input->dims()[1]);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(batch_size, 0,
|
|
|
|
|
"batch_size(dims[0]) must be nonnegative.");
|
|
|
|
|
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
|
|
|
|
|
|
|
|
|
|
std::vector<T> ins_vector;
|
|
|
|
|
framework::TensorToVector(*input, context.device_context(), &ins_vector);
|
|
|
|
|
|
|
|
|
@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
T min = static_cast<T>(context.Attr<float>("min"));
|
|
|
|
|
T max = static_cast<T>(context.Attr<float>("max"));
|
|
|
|
|
UniformGenerator<T> gen = UniformGenerator<T>(min, max, seed);
|
|
|
|
|
|
|
|
|
|
std::vector<T> ids(batch_size);
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
T r = UniformGenerator<T>(min, max, seed);
|
|
|
|
|
T r = gen(0);
|
|
|
|
|
int idx = width - 1;
|
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
|
if ((r -= ins_vector[i * width + j]) < 0) {
|
|
|
|
|