|
|
|
@ -15,30 +15,31 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
/// Produces random floating-point values, uniformly distributed on [0, 1).
|
|
|
|
|
std::uniform_real_distribution<double> rand1_;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const Tensor* input = context.Input<Tensor>("X");
|
|
|
|
|
const int batch_size = static_cast<int>(input->dims()[0]);
|
|
|
|
|
const int width = static_cast<int>(input->dims()[1]);
|
|
|
|
|
|
|
|
|
|
std::vector<int> ids(batchSize);
|
|
|
|
|
auto& reng = get();
|
|
|
|
|
std::vector<T> ins_vector;
|
|
|
|
|
framework::TensorToVector(*input, context.device_context(), &ins_vector);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batchSize; ++i) {
|
|
|
|
|
double r = rand1_(reng);
|
|
|
|
|
int id = dim - 1;
|
|
|
|
|
for (int j = 0; j < dim; ++j) {
|
|
|
|
|
if ((r -= buf[i * dim + j]) < 0) {
|
|
|
|
|
std::vector<int> ids(batch_size);
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
double r = this->get_rand();
|
|
|
|
|
int id = width - 1;
|
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
|
if ((r -= ins_vector[i * width + j]) < 0) {
|
|
|
|
|
id = j;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
out_dim.push_back(static_cast<int64_t>(batch_size));
|
|
|
|
|
|
|
|
|
|
Tensor* output = context.Output<Tensor>("Output");
|
|
|
|
|
output->Resize(framework::make_ddim(in_dim));
|
|
|
|
|
output->Resize(framework::make_ddim(out_dim));
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
framework::TensorFromVector(ids, context.device_context(), output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::default_random_engine& get() {
|
|
|
|
|
auto engine = new std::default_random_engine;
|
|
|
|
|
engine->seed(defaultSeed);
|
|
|
|
|
return *engine;
|
|
|
|
|
double get_rand() const {
|
|
|
|
|
// Will be used to obtain a seed for the random number engine
|
|
|
|
|
std::random_device rd;
|
|
|
|
|
// Standard mersenne_twister_engine seeded with rd()
|
|
|
|
|
std::mt19937 gen(rd());
|
|
|
|
|
std::uniform_real_distribution<> dis(0, 1);
|
|
|
|
|
return dis(gen);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
unsigned int defaultSeed = 0;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|