|
|
|
@ -30,6 +30,8 @@ class SamplingIdOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Output(Out) of SamplingIdOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 2,
|
|
|
|
|
"Input(X, Filter) should be 2-D tensor.");
|
|
|
|
|
|
|
|
|
|
framework::DDim dims = input_dims;
|
|
|
|
|
ctx->SetOutputDim("Out", dims);
|
|
|
|
@ -46,10 +48,8 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddOutput("Out", "SamplingId data tensor.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
SamplingId Operator.
|
|
|
|
|
@brief A layer for sampling id from multinomial distribution from the
|
|
|
|
|
input layer. Sampling one id for one sample. The result is stored in
|
|
|
|
|
output_.ids.
|
|
|
|
|
)DOC");
|
|
|
|
|
A layer for sampling id from multinomial distribution from the
|
|
|
|
|
input layer. Sampling one id for one sample.)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|