|
|
|
@ -118,9 +118,10 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
"bilinear" == interp_method || "nearest" == interp_method ||
|
|
|
|
|
"bicubic" == interp_method,
|
|
|
|
|
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
|
|
|
|
|
"Input(X) dimension is 4, but got method = %s .",
|
|
|
|
|
interp_method);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
|
|
|
|
|
"Input(X) dimension is 4, but got method = %s.",
|
|
|
|
|
interp_method));
|
|
|
|
|
const DataLayout data_layout = framework::StringToDataLayout(
|
|
|
|
|
ctx->Attrs().Get<std::string>("data_layout"));
|
|
|
|
|
|
|
|
|
@ -305,12 +306,15 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
|
|
|
|
|
auto out_size_dim = ctx->GetInputDim("OutSize");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
|
|
|
|
|
"OutSize's dimension size must be 1, but got size =%d .",
|
|
|
|
|
out_size_dim.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_size_dim.size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"OutSize's dimension size must be 1, but got size is %d.",
|
|
|
|
|
out_size_dim.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_size_dim[0], 3,
|
|
|
|
|
"OutSize's dim[0] must be 3, but got size = %d .",
|
|
|
|
|
out_size_dim[0]);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"OutSize's dim[0] must be 3, but got size is %d.",
|
|
|
|
|
out_size_dim[0]));
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -330,10 +334,8 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of InterpolateV2Op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of InterpolationOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Interpolate");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Interpolate");
|
|
|
|
|
|
|
|
|
|
auto dim_x = ctx->GetInputDim("X"); // NCHW format
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
@ -576,9 +578,10 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InterpolateGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "InterpolateGrad");
|
|
|
|
|
|
|
|
|
|
auto dim_x = ctx->GetInputDim("X");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
|
|
|
|
|