Fix round in grid sample op (#27657)

my_2.0rc
whs 5 years ago committed by GitHub
parent 3ccee08285
commit daf5aa9b8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -238,9 +238,8 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
}
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(round(ix));
int iy_nearest = static_cast<int>(round(iy));
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
@ -403,8 +402,8 @@ __global__ void grid_sampler_cuda_backward_kernel(
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;

Loading…
Cancel
Save