|
|
|
@ -45,9 +45,14 @@ class InterpolateOp : public framework::OperatorWithKernel {
|
|
|
|
|
// round down
|
|
|
|
|
out_h = static_cast<int>(dim_x[2] * scale);
|
|
|
|
|
out_w = static_cast<int>(dim_x[3] * scale);
|
|
|
|
|
// protect when input shape is -1
|
|
|
|
|
out_h = out_h > 0 ? out_h : -1;
|
|
|
|
|
out_w = out_w > 0 ? out_w : -1;
|
|
|
|
|
} else {
|
|
|
|
|
out_h = ctx->Attrs().Get<int>("out_h");
|
|
|
|
|
out_w = ctx->Attrs().Get<int>("out_w");
|
|
|
|
|
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
|
|
|
|
|
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
|
|
|
|
@ -59,12 +64,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
|
|
|
|
|
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
|
|
|
|
|
} else {
|
|
|
|
|
ctx->SetOutputDim("Out", dim_x);
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|