|
|
|
@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
|
|
|
|
|
int in_chw = channels * in_hw;
|
|
|
|
|
int out_chw = channels * out_hw;
|
|
|
|
|
|
|
|
|
|
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
|
|
|
|
|
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
|
|
|
|
|
float ratio_h =
|
|
|
|
|
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
|
|
|
|
|
float ratio_w =
|
|
|
|
|
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
|
|
|
|
|
|
|
|
|
|
if (in_h == out_h && in_w == out_w) {
|
|
|
|
|
memcpy(output, input, input_t->numel() * sizeof(T));
|
|
|
|
@ -56,24 +58,24 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (int i = 0; i < out_h; ++i) { // loop for images
|
|
|
|
|
int h = ratio_h * i;
|
|
|
|
|
int hid = (h < in_h - 1) ? 1 : 0;
|
|
|
|
|
T h1lambda = ratio_h * i - h;
|
|
|
|
|
T h2lambda = 1 - h1lambda;
|
|
|
|
|
float h1lambda = ratio_h * i - h;
|
|
|
|
|
float h2lambda = 1.f - h1lambda;
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < out_w; ++j) {
|
|
|
|
|
int w = ratio_w * j;
|
|
|
|
|
int wid = (w < in_w - 1) ? 1 : 0;
|
|
|
|
|
T w1lambda = ratio_w * j - w;
|
|
|
|
|
T w2lambda = 1 - w1lambda;
|
|
|
|
|
float w1lambda = ratio_w * j - w;
|
|
|
|
|
float w2lambda = 1.f - w1lambda;
|
|
|
|
|
// calculate four position for bilinear interpolation
|
|
|
|
|
const T* in_pos = &input[k * in_chw + h * in_w + w];
|
|
|
|
|
T* out_pos = &output[k * out_chw + i * out_w + j];
|
|
|
|
|
|
|
|
|
|
for (int c = 0; c < channels; ++c) { // loop for channels
|
|
|
|
|
// bilinear interpolation
|
|
|
|
|
out_pos[0] =
|
|
|
|
|
out_pos[0] = static_cast<T>(
|
|
|
|
|
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
|
|
|
|
|
h1lambda * (w2lambda * in_pos[hid * in_w] +
|
|
|
|
|
w1lambda * in_pos[hid * in_w + wid]);
|
|
|
|
|
w1lambda * in_pos[hid * in_w + wid]));
|
|
|
|
|
in_pos += in_hw;
|
|
|
|
|
out_pos += out_hw;
|
|
|
|
|
}
|
|
|
|
@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
int in_chw = channels * in_hw;
|
|
|
|
|
int out_chw = channels * out_hw;
|
|
|
|
|
|
|
|
|
|
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
|
|
|
|
|
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
|
|
|
|
|
float ratio_h =
|
|
|
|
|
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
|
|
|
|
|
float ratio_w =
|
|
|
|
|
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
|
|
|
|
|
|
|
|
|
|
if (in_h == out_h && in_w == out_w) {
|
|
|
|
|
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
|
|
|
|
@ -127,22 +131,24 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (int i = 0; i < out_h; ++i) { // loop for images
|
|
|
|
|
int h = ratio_h * i;
|
|
|
|
|
int hid = (h < in_h - 1) ? 1 : 0;
|
|
|
|
|
T h1lambda = ratio_h * i - h;
|
|
|
|
|
T h2lambda = 1 - h1lambda;
|
|
|
|
|
float h1lambda = ratio_h * i - h;
|
|
|
|
|
float h2lambda = 1 - h1lambda;
|
|
|
|
|
|
|
|
|
|
for (int j = 0; j < out_w; ++j) {
|
|
|
|
|
int w = ratio_w * j;
|
|
|
|
|
int wid = (w < in_w - 1) ? 1 : 0;
|
|
|
|
|
T w1lambda = ratio_w * j - w;
|
|
|
|
|
T w2lambda = 1 - w1lambda;
|
|
|
|
|
float w1lambda = ratio_w * j - w;
|
|
|
|
|
float w2lambda = 1 - w1lambda;
|
|
|
|
|
T* in_pos = &d_input[k * in_chw + h * in_w + w];
|
|
|
|
|
const T* out_pos = &d_output[k * out_chw + i * out_w + j];
|
|
|
|
|
|
|
|
|
|
for (int c = 0; c < channels; ++c) { // loop for channels
|
|
|
|
|
in_pos[0] += h2lambda * w2lambda * out_pos[0];
|
|
|
|
|
in_pos[wid] += h2lambda * w1lambda * out_pos[0];
|
|
|
|
|
in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0];
|
|
|
|
|
in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0];
|
|
|
|
|
in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
|
|
|
|
|
in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
|
|
|
|
|
in_pos[hid * in_w] +=
|
|
|
|
|
static_cast<T>(h1lambda * w2lambda * out_pos[0]);
|
|
|
|
|
in_pos[hid * in_w + wid] +=
|
|
|
|
|
static_cast<T>(h1lambda * w1lambda * out_pos[0]);
|
|
|
|
|
in_pos += in_hw;
|
|
|
|
|
out_pos += out_hw;
|
|
|
|
|
}
|
|
|
|
|