SamplingID Op fix error print (#24521)

* fix error print for sampling_id_op

* fix spell err

* fix spell err test=develop
release/2.0-alpha
Jiawei Wang 6 years ago committed by GitHub
parent 86ca31ab58
commit 4a105f803e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SamplingIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SamplingIdOp should not be null.");
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
ctx->Attrs().Get<float>("max"), "min must less then max");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SampleIn");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "SampleOut");
PADDLE_ENFORCE_LT(
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
platform::errors::InvalidArgument(
"min must less then max, but here min is %f, max is %f",
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE_EQ(
input_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X, Filter) should be 2-D tensor. But X dim is %d",
input_dims.size()));
auto dim0 = input_dims[0];
framework::DDim dims = framework::make_ddim({dim0});

@ -36,9 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]);
PADDLE_ENFORCE_GE(batch_size, 0,
"batch_size(dims[0]) must be nonnegative.");
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
PADDLE_ENFORCE_GE(
batch_size, 0,
platform::errors::InvalidArgument(
"batch_size(dims[0]) must be nonnegative. but it is %d.",
batch_size));
PADDLE_ENFORCE_GE(
width, 0,
platform::errors::InvalidArgument(
"width(dims[1]) must be nonnegative. but it is %d.", width));
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);

Loading…
Cancel
Save