|
|
|
@ -174,8 +174,8 @@ class InterpolateKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
float scale = ctx.Attr<float>("scale");
|
|
|
|
|
if (scale > 0) {
|
|
|
|
|
out_h = in_h * scale;
|
|
|
|
|
out_w = in_w * scale;
|
|
|
|
|
out_h = static_cast<int>(in_h * scale);
|
|
|
|
|
out_w = static_cast<int>(in_w * scale);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_size = ctx.Input<Tensor>("OutSize");
|
|
|
|
@ -239,8 +239,8 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
float scale = ctx.Attr<float>("scale");
|
|
|
|
|
if (scale > 0) {
|
|
|
|
|
out_h = in_h * scale;
|
|
|
|
|
out_w = in_w * scale;
|
|
|
|
|
out_h = static_cast<int>(in_h * scale);
|
|
|
|
|
out_w = static_cast<int>(in_w * scale);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_size = ctx.Input<Tensor>("OutSize");
|
|
|
|
|