|
|
|
@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
std::vector<T> ids(batch_size);
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
double r = this->get_rand();
|
|
|
|
|
double r = this->getRandReal();
|
|
|
|
|
int idx = width - 1;
|
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
|
if ((r -= ins_vector[i * width + j]) < 0) {
|
|
|
|
@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
framework::TensorFromVector(ids, context.device_context(), output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double get_rand() const {
|
|
|
|
|
private:
|
|
|
|
|
double getRandReal() const {
|
|
|
|
|
std::call_once(init_flag_, &SamplingIdKernel::getRndInstance);
|
|
|
|
|
return rnd();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void getRndInstance() {
|
|
|
|
|
// Will be used to obtain a seed for the random number engine
|
|
|
|
|
std::random_device rd;
|
|
|
|
|
// Standard mersenne_twister_engine seeded with rd()
|
|
|
|
|
std::mt19937 gen(rd());
|
|
|
|
|
std::uniform_real_distribution<> dis(0, 1);
|
|
|
|
|
return dis(gen);
|
|
|
|
|
rnd = std::bind(dis, gen);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int defaultSeed = 0;
|
|
|
|
|
static std::once_flag init_flag_;
|
|
|
|
|
static std::function<> rnd;
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|