|
|
|
@ -53,8 +53,13 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
static_cast<T>(context.Attr<float>("min")),
|
|
|
|
|
static_cast<T>(context.Attr<float>("max")));
|
|
|
|
|
|
|
|
|
|
<<<<<<< HEAD
|
|
|
|
|
std::vector<int64_t> ids(batch_size);
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
=======
|
|
|
|
|
std::vector<T> ids(batch_size);
|
|
|
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
|
|
|
>>>>>>> 823c4f87beff04e4029e3f4a183658621ca8f01b
|
|
|
|
|
T r = dist(engine);
|
|
|
|
|
int idx = width - 1;
|
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
@ -63,7 +68,11 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
<<<<<<< HEAD
|
|
|
|
|
ids[i] = int64_t(idx);
|
|
|
|
|
=======
|
|
|
|
|
ids[i] = ins_vector[idx];
|
|
|
|
|
>>>>>>> 823c4f87beff04e4029e3f4a183658621ca8f01b
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_dim;
|
|
|
|
|