|
|
|
@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<T> ins_vector;
|
|
|
|
|
framework::TensorToVector(*input, context.device_context(), &ins_vector);
|
|
|
|
|
|
|
|
|
|
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
|
|
|
|
|
std::minstd_rand engine;
|
|
|
|
|
if (seed == 0) {
|
|
|
|
|
seed = std::random_device()();
|
|
|
|
|
}
|
|
|
|
|
engine.seed(seed);
|
|
|
|
|
std::uniform_real_distribution<T> dist(
|
|
|
|
|
static_cast<T>(ctx.Attr<float>("min")),
|
|
|
|
|
static_cast<T>(ctx.Attr<float>("max")));
|
|
|
|
|
|
|
|
|
|
std::vector<T> ids(batch_size);
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
double r = getRandReal();
|
|
|
|
|
double r = dist(engine);
|
|
|
|
|
int idx = width - 1;
|
|
|
|
|
for (int j = 0; j < width; ++j) {
|
|
|
|
|
if ((r -= ins_vector[i * width + j]) < 0) {
|
|
|
|
@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
framework::TensorFromVector(ids, context.device_context(), output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
double getRandReal() const {
|
|
|
|
|
std::random_device
|
|
|
|
|
rd; // Will be used to obtain a seed for the random number engine
|
|
|
|
|
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with
|
|
|
|
|
// rd()
|
|
|
|
|
std::uniform_real_distribution<> dis(1.0, 2.0);
|
|
|
|
|
return dis(gen);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SamplingIdOp : public framework::OperatorWithKernel {
|
|
|
|
@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(X) of SamplingIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SamplingIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
|
|
|
|
|
"min must less then max");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 2,
|
|
|
|
@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
SamplingId Operator.
|
|
|
|
|
A layer for sampling id from multinomial distribution from the
|
|
|
|
|
input layer. Sampling one id for one sample.)DOC");
|
|
|
|
|
input. Sampling one id for one sample.)DOC");
|
|
|
|
|
AddAttr<float>("min", "Minimum value of random. [default 0.0].")
|
|
|
|
|
.SetDefault(0.0f);
|
|
|
|
|
AddAttr<float>("max", "Maximun value of random. [default 1.0].")
|
|
|
|
|
.SetDefault(1.0f);
|
|
|
|
|
AddAttr<int>("seed",
|
|
|
|
|
"Random seed used for the random number engine. "
|
|
|
|
|
"0 means use a seed generated by the system."
|
|
|
|
|
"Note that if seed is not 0, this operator will always "
|
|
|
|
|
"generate the same random numbers every time. [default 0].")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
@ -109,8 +122,5 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
|
|
|
|
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
|
|
|
|
|
paddle::operators::SamplingIdKernel<double>);
|
|
|
|
|