round down for scale. test=develop

revert-16734-refine/test_imperative_transformer
dengkaipeng 6 years ago
parent 2078f4207f
commit 0f7411a1ae

@ -41,8 +41,9 @@ class InterpolateOp : public framework::OperatorWithKernel {
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
out_h = dim_x[2] * scale;
out_w = dim_x[3] * scale;
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");

@ -174,8 +174,8 @@ class InterpolateKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
@ -239,8 +239,8 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");

Loading…
Cancel
Save