|
|
@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
|
|
|
|
limitations under the License. */
|
|
|
|
limitations under the License. */
|
|
|
|
#pragma once
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
#include <iterator>
|
|
|
|
#include <random>
|
|
|
|
#include <random>
|
|
|
|
|
|
|
|
#include <sstream>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
std::vector<T> ins_vector;
|
|
|
|
std::vector<T> ins_vector;
|
|
|
|
framework::TensorToVector(*input, context.device_context(), &ins_vector);
|
|
|
|
framework::TensorToVector(*input, context.device_context(), &ins_vector);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> ids(batch_size);
|
|
|
|
std::vector<T> ids(batch_size);
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
double r = this->get_rand();
|
|
|
|
double r = this->get_rand();
|
|
|
|
int id = width - 1;
|
|
|
|
int idx = width - 1;
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
if ((r -= ins_vector[i * width + j]) < 0) {
|
|
|
|
if ((r -= ins_vector[i * width + j]) < 0) {
|
|
|
|
id = j;
|
|
|
|
idx = j;
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ids[i] = id;
|
|
|
|
ids[i] = ins_vector[i * width + idx];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_dim;
|
|
|
|
std::vector<int64_t> out_dim;
|
|
|
|