|
|
|
|
@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of SamplingIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SamplingIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
|
|
|
|
|
ctx->Attrs().Get<float>("max"), "min must less then max");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SampleIn");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "SampleOut");
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"min must less then max, but here min is %f, max is %f",
|
|
|
|
|
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 2,
|
|
|
|
|
"Input(X, Filter) should be 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X, Filter) should be 2-D tensor. But X dim is %d",
|
|
|
|
|
input_dims.size()));
|
|
|
|
|
|
|
|
|
|
auto dim0 = input_dims[0];
|
|
|
|
|
framework::DDim dims = framework::make_ddim({dim0});
|
|
|
|
|
|