commit
72ee737f3f
@ -0,0 +1,94 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/bilinear_interp_op.h"
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
|
||||
class BilinearInterpOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of BilinearInterOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of BilinearInterOp should not be null.");
|
||||
|
||||
auto dim_x = ctx->GetInputDim("X"); // NCHW format
|
||||
int out_h = ctx->Attrs().Get<int>("out_h");
|
||||
int out_w = ctx->Attrs().Get<int>("out_w");
|
||||
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
|
||||
|
||||
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
|
||||
}
|
||||
};
|
||||
|
||||
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X",
|
||||
"(Tensor) The input tensor of bilinear interpolation, "
|
||||
"This is a 4-D tensor with shape of (N x C x h x w)");
|
||||
AddOutput("Out",
|
||||
"(Tensor) The dimension of output is (N x C x out_h x out_w]");
|
||||
|
||||
AddAttr<int>("out_h", "(int) output height of bilinear interpolation op.");
|
||||
AddAttr<int>("out_w", "(int) output width of bilinear interpolation op.");
|
||||
AddComment(R"DOC(
|
||||
Bilinear interpolation is an extension of linear interpolation for
|
||||
interpolating functions of two variables (e.g. H-direction and
|
||||
W-direction in this op) on a rectilinear 2D grid.
|
||||
|
||||
The key idea is to perform linear interpolation first in one
|
||||
direction, and then again in the other direction.
|
||||
|
||||
For details, please refer to Wikipedia:
|
||||
https://en.wikipedia.org/wiki/Bilinear_interpolation
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class BilinearInterpOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
||||
"Input(Out@GRAD) should not be null");
|
||||
auto dim_x = ctx->GetInputDim("X");
|
||||
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
|
||||
ops::BilinearInterpOpMaker,
|
||||
paddle::framework::DefaultGradOpDescMaker<true>);
|
||||
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
|
||||
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>);
|
||||
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
|
||||
ops::BilinearInterpGradKernel<float>);
|
@ -0,0 +1,186 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/bilinear_interp_op.h"
|
||||
#include "paddle/fluid/platform/cuda_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
__global__ void KeBilinearInterpFw(
|
||||
const T* in, const size_t in_img_h, const size_t in_img_w,
|
||||
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
|
||||
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
||||
const size_t num_channels, const T ratio_h, const T ratioW) {
|
||||
int nthreads = output_h * output_w;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < nthreads) {
|
||||
int out_id_h = tid / output_w;
|
||||
int out_id_w = tid % output_w;
|
||||
int in_img_size = input_w / num_channels;
|
||||
int out_img_size = output_w / num_channels;
|
||||
int channel_id = out_id_w / out_img_size;
|
||||
|
||||
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
||||
int in_img_idy = ratio_h * out_img_idy;
|
||||
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
|
||||
T h1lambda = ratio_h * out_img_idy - in_img_idy;
|
||||
T h2lambda = 1.f - h1lambda;
|
||||
|
||||
int out_img_idx = tid % out_img_w;
|
||||
int in_img_idx = ratioW * out_img_idx;
|
||||
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
|
||||
T w1lambda = ratioW * out_img_idx - in_img_idx;
|
||||
T w2lambda = 1.f - w1lambda;
|
||||
|
||||
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
|
||||
in_img_idy * in_img_w + in_img_idx];
|
||||
|
||||
// bilinear interpolation
|
||||
out[out_id_h * output_w + out_id_w] =
|
||||
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
|
||||
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
|
||||
w1lambda * in_pos[h_id * in_img_w + w_id]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void KeBilinearInterpBw(
|
||||
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
|
||||
const size_t input_w, const T* out, const size_t out_img_h,
|
||||
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
||||
const size_t num_channels, const T ratio_h, const T ratioW) {
|
||||
int nthreads = output_h * output_w;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < nthreads) {
|
||||
int out_id_h = tid / output_w;
|
||||
int out_id_w = tid % output_w;
|
||||
int in_img_size = input_w / num_channels;
|
||||
int out_img_size = output_w / num_channels;
|
||||
int channel_id = out_id_w / out_img_size;
|
||||
|
||||
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
||||
int in_img_idy = ratio_h * out_img_idy;
|
||||
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
|
||||
T h1lambda = ratio_h * out_img_idy - in_img_idy;
|
||||
T h2lambda = 1.f - h1lambda;
|
||||
|
||||
int out_img_idx = tid % out_img_w;
|
||||
int in_img_idx = ratioW * out_img_idx;
|
||||
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
|
||||
T w1lambda = ratioW * out_img_idx - in_img_idx;
|
||||
T w2lambda = 1.f - w1lambda;
|
||||
|
||||
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
|
||||
in_img_idy * in_img_w + in_img_idx];
|
||||
const T* out_pos = &out[out_id_h * output_w + out_id_w];
|
||||
atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
|
||||
atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
|
||||
atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
|
||||
atomicAdd(&in_pos[h_id * in_img_w + w_id],
|
||||
h1lambda * w1lambda * out_pos[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
|
||||
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
|
||||
auto* input = input_t->data<T>();
|
||||
auto* output = output_t->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int out_h = ctx.Attr<int>("out_h");
|
||||
int out_w = ctx.Attr<int>("out_w");
|
||||
int batch_size = input_t->dims()[0];
|
||||
int channels = input_t->dims()[1];
|
||||
int in_h = input_t->dims()[2];
|
||||
int in_w = input_t->dims()[3];
|
||||
|
||||
int in_hw = in_h * in_w;
|
||||
int out_hw = out_h * out_w;
|
||||
int in_chw = channels * in_hw;
|
||||
int out_chw = channels * out_hw;
|
||||
|
||||
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
|
||||
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
|
||||
|
||||
if (in_h == out_h && in_w == out_w) {
|
||||
memcpy(output, input, input_t->numel() * sizeof(T));
|
||||
} else {
|
||||
int threadNum = batch_size * out_chw;
|
||||
int blocks = (threadNum + 1024 - 1) / 1024;
|
||||
|
||||
KeBilinearInterpFw<
|
||||
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
|
||||
input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
|
||||
batch_size, out_chw, channels, ratio_h, ratio_w);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
|
||||
auto* d_output = d_output_t->data<T>();
|
||||
|
||||
auto& device_ctx =
|
||||
ctx.template device_context<platform::CUDADeviceContext>();
|
||||
math::SetConstant<platform::CUDADeviceContext, T> zero;
|
||||
zero(device_ctx, d_input_t, static_cast<T>(0.0));
|
||||
|
||||
int out_h = ctx.Attr<int>("out_h");
|
||||
int out_w = ctx.Attr<int>("out_w");
|
||||
int batch_size = d_input_t->dims()[0];
|
||||
int channels = d_input_t->dims()[1];
|
||||
int in_h = d_input_t->dims()[2];
|
||||
int in_w = d_input_t->dims()[3];
|
||||
|
||||
int in_hw = in_h * in_w;
|
||||
int out_hw = out_h * out_w;
|
||||
int in_chw = channels * in_hw;
|
||||
int out_chw = channels * out_hw;
|
||||
|
||||
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
|
||||
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
|
||||
|
||||
if (in_h == out_h && in_w == out_w) {
|
||||
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
|
||||
} else {
|
||||
int threadNum = batch_size * out_chw;
|
||||
int blocks = (threadNum + 1024 - 1) / 1024;
|
||||
|
||||
KeBilinearInterpBw<
|
||||
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
|
||||
d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
|
||||
batch_size, out_chw, channels, ratio_h, ratio_w);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
|
||||
ops::BilinearInterpOpCUDAKernel<float>);
|
||||
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
|
||||
ops::BilinearInterpGradOpCUDAKernel<float>);
|
@ -0,0 +1,143 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
class BilinearInterpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
|
||||
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
|
||||
auto* input = input_t->data<T>();
|
||||
auto* output = output_t->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int out_h = ctx.Attr<int>("out_h");
|
||||
int out_w = ctx.Attr<int>("out_w");
|
||||
int batch_size = input_t->dims()[0];
|
||||
int channels = input_t->dims()[1];
|
||||
int in_h = input_t->dims()[2];
|
||||
int in_w = input_t->dims()[3];
|
||||
|
||||
int in_hw = in_h * in_w;
|
||||
int out_hw = out_h * out_w;
|
||||
int in_chw = channels * in_hw;
|
||||
int out_chw = channels * out_hw;
|
||||
|
||||
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
|
||||
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
|
||||
|
||||
if (in_h == out_h && in_w == out_w) {
|
||||
memcpy(output, input, input_t->numel() * sizeof(T));
|
||||
} else {
|
||||
for (int k = 0; k < batch_size; ++k) { // loop for batches
|
||||
for (int i = 0; i < out_h; ++i) { // loop for images
|
||||
int h = ratio_h * i;
|
||||
int hid = (h < in_h - 1) ? 1 : 0;
|
||||
T h1lambda = ratio_h * i - h;
|
||||
T h2lambda = 1 - h1lambda;
|
||||
|
||||
for (int j = 0; j < out_w; ++j) {
|
||||
int w = ratio_w * j;
|
||||
int wid = (w < in_w - 1) ? 1 : 0;
|
||||
T w1lambda = ratio_w * j - w;
|
||||
T w2lambda = 1 - w1lambda;
|
||||
// calculate four position for bilinear interpolation
|
||||
const T* in_pos = &input[k * in_chw + h * in_w + w];
|
||||
T* out_pos = &output[k * out_chw + i * out_w + j];
|
||||
|
||||
for (int c = 0; c < channels; ++c) { // loop for channels
|
||||
// bilinear interpolation
|
||||
out_pos[0] =
|
||||
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
|
||||
h1lambda * (w2lambda * in_pos[hid * in_w] +
|
||||
w1lambda * in_pos[hid * in_w + wid]);
|
||||
in_pos += in_hw;
|
||||
out_pos += out_hw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BilinearInterpGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
|
||||
auto* d_output = d_output_t->data<T>();
|
||||
|
||||
auto& device_ctx =
|
||||
ctx.template device_context<platform::CPUDeviceContext>();
|
||||
math::SetConstant<platform::CPUDeviceContext, T> zero;
|
||||
zero(device_ctx, d_input_t, static_cast<T>(0.0));
|
||||
|
||||
int out_h = ctx.Attr<int>("out_h");
|
||||
int out_w = ctx.Attr<int>("out_w");
|
||||
int batch_size = d_input_t->dims()[0];
|
||||
int channels = d_input_t->dims()[1];
|
||||
int in_h = d_input_t->dims()[2];
|
||||
int in_w = d_input_t->dims()[3];
|
||||
|
||||
int in_hw = in_h * in_w;
|
||||
int out_hw = out_h * out_w;
|
||||
int in_chw = channels * in_hw;
|
||||
int out_chw = channels * out_hw;
|
||||
|
||||
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
|
||||
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
|
||||
|
||||
if (in_h == out_h && in_w == out_w) {
|
||||
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
|
||||
} else {
|
||||
for (int k = 0; k < batch_size; ++k) { // loop for batches
|
||||
for (int i = 0; i < out_h; ++i) { // loop for images
|
||||
int h = ratio_h * i;
|
||||
int hid = (h < in_h - 1) ? 1 : 0;
|
||||
T h1lambda = ratio_h * i - h;
|
||||
T h2lambda = 1 - h1lambda;
|
||||
|
||||
for (int j = 0; j < out_w; ++j) {
|
||||
int w = ratio_w * j;
|
||||
int wid = (w < in_w - 1) ? 1 : 0;
|
||||
T w1lambda = ratio_w * j - w;
|
||||
T w2lambda = 1 - w1lambda;
|
||||
T* in_pos = &d_input[k * in_chw + h * in_w + w];
|
||||
const T* out_pos = &d_output[k * out_chw + i * out_w + j];
|
||||
|
||||
for (int c = 0; c < channels; ++c) { // loop for channels
|
||||
in_pos[0] += h2lambda * w2lambda * out_pos[0];
|
||||
in_pos[wid] += h2lambda * w1lambda * out_pos[0];
|
||||
in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0];
|
||||
in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0];
|
||||
in_pos += in_hw;
|
||||
out_pos += out_hw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def bilinear_interp_np(input, out_h, out_w):
|
||||
batch_size, channel, in_h, in_w = input.shape
|
||||
if out_h > 1:
|
||||
ratio_h = (in_h - 1.0) / (out_h - 1.0)
|
||||
else:
|
||||
ratio_h = 0.0
|
||||
if out_w > 1:
|
||||
ratio_w = (in_w - 1.0) / (out_w - 1.0)
|
||||
else:
|
||||
ratio_w = 0.0
|
||||
|
||||
out = np.zeros((batch_size, channel, out_h, out_w))
|
||||
for i in range(out_h):
|
||||
h = int(ratio_h * i)
|
||||
hid = 1 if h < in_h - 1 else 0
|
||||
h1lambda = ratio_h * i - h
|
||||
h2lambda = 1.0 - h1lambda
|
||||
for j in range(out_w):
|
||||
w = int(ratio_w * j)
|
||||
wid = 1 if w < in_w - 1 else 0
|
||||
w1lambda = ratio_w * j - w
|
||||
w2lambda = 1.0 - w1lambda
|
||||
|
||||
out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
|
||||
w1lambda*input[:, :, h, w+wid]) + \
|
||||
h1lambda*(w2lambda*input[:, :, h+hid, w] +
|
||||
w1lambda*input[:, :, h+hid, w+wid])
|
||||
return out.astype("float32")
|
||||
|
||||
|
||||
class TestBilinearInterpOp(OpTest):
|
||||
def setUp(self):
|
||||
self.init_test_case()
|
||||
self.op_type = "bilinear_interp"
|
||||
input_np = np.random.random(self.input_shape).astype("float32")
|
||||
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w)
|
||||
|
||||
self.inputs = {'X': input_np}
|
||||
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
|
||||
self.outputs = {'Out': output_np}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], 'Out', in_place=True)
|
||||
|
||||
def init_test_case(self):
|
||||
self.input_shape = [2, 3, 4, 4]
|
||||
self.out_h = 2
|
||||
self.out_w = 2
|
||||
|
||||
|
||||
class TestCase1(TestBilinearInterpOp):
|
||||
def init_test_case(self):
|
||||
self.input_shape = [4, 1, 7, 8]
|
||||
self.out_h = 1
|
||||
self.out_w = 1
|
||||
|
||||
|
||||
class TestCase2(TestBilinearInterpOp):
|
||||
def init_test_case(self):
|
||||
self.input_shape = [3, 3, 9, 6]
|
||||
self.out_h = 12
|
||||
self.out_w = 12
|
||||
|
||||
|
||||
class TestCase3(TestBilinearInterpOp):
|
||||
def init_test_case(self):
|
||||
self.input_shape = [1, 1, 128, 64]
|
||||
self.out_h = 64
|
||||
self.out_w = 128
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue