Add trilinear_interp OP (#18711)

* add trilinear interp. test=develop

* fix unittest. test=develop

* add python api and test_layers. test=develop

* refine API.spec. test=develop

* fix format. test=develop

* add python API test. test=develop

* format code. test=develop

* refine code strcuture. test=develop

* fix format

* fix doc. test=develop

* fix converage. test=develop

* fix format. test=develop
padding_in_crf
Kaipeng Deng 6 years ago committed by GitHub
parent c2063217e7
commit f86fead693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -173,10 +173,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon'
paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5'))
paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591'))
paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', 'a29488d94d9a4bc4434d8a3529b4c6fe'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '8cfc4f69dbbedb687b6c20732aa8f09e'))
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '548c7c2ead5771d15abbaad505f901e9'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', 'b7d810d1e251c5957c1efa6aa699d2d0'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '832b2412652d84a6631b1012c6e2d18b'))
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf'))

@ -20,6 +20,85 @@ namespace operators {
using framework::Tensor;
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
ctx->ShareLoD("X", "Out");
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE("trilinear" == interp_method,
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5");
int out_d, out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_d = static_cast<int>(dim_x[2] * scale);
out_h = static_cast<int>(dim_x[3] * scale);
out_w = static_cast<int>(dim_x[4] * scale);
// protect when input shape is -1
out_d = out_d > 0 ? out_d : -1;
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_d = ctx->Attrs().Get<int>("out_d");
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_d, 0, "out_d should be greater than 0.");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3");
ctx->ShareLoD("X", "Out");
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
class InterpolateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@ -31,41 +110,17 @@ class InterpolateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of InterpolationOp should not be null.");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\".");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
ctx->ShareLoD("X", "Out");
return;
PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5,
"Input(X) dimension must be 4 or 5");
if (dim_x.size() == 4) {
// shape check for 2D interpolate for input tensor shape NCHW
Interpolate2DInferShapeCheck(ctx);
} else { // dim_x.size() == 5
// shape check for 3D interpolate for input tensor shape NCDHW
Interpolate3DInferShapeCheck(ctx);
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
protected:
@ -81,22 +136,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"The input tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, w].");
"This is a 4-D tensor with shape of [N, C, H, W] or a "
"5-D tensor with shape of [N, C, D, H, W].");
AddInput("OutSize",
"This is a 1-D tensor with two numbers to specify output size. "
"The first number is height and the second number is width.")
"It should be [output_height, output_width] when input is a 4-D "
"tensor and should be [output_depth, output_height, output_width] "
"when input is a 5-D tensor.")
.AsDispensable();
AddOutput("Out",
"The output tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W].");
"This is a tensor in same rank with Input(X).");
AddAttr<int>("out_h", "output height of interpolate op.");
AddAttr<int>("out_w", "output width of interpolate op.");
AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation "
"method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest "
"bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation.")
.SetDefault("bilinear");
AddAttr<bool>(
@ -127,6 +187,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optinal parameters,the calculation method
of interpolation can be selected by them.
@ -183,6 +248,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
@ -190,6 +276,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation
)DOC");
}
};
@ -251,6 +340,10 @@ REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
@ -261,3 +354,8 @@ REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -205,6 +205,17 @@ class TestBilinearInterpCase6(TestBilinearInterpOp):
self.align_mode = 1
class TestBilinearInterpSame(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 128, 64]
self.out_h = 128
self.out_w = 64
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpActualShape(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'

@ -1295,16 +1295,74 @@ class TestBook(LayerTest):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_bilinear(x, out_shape=[12, 12])
return (output)
output = layers.resize_bilinear(x, scale=3)
def make_resize_bilinear_by_scale(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_bilinear(x, scale=1.5)
return (output)
def make_resize_nearest(self):
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12])
except ValueError:
pass
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(
name='x2', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12, 12])
except ValueError:
pass
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12])
return (output)
output = layers.resize_nearest(x, scale=3)
def make_resize_nearest_by_scale(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, scale=1.8)
return (output)
def make_resize_trilinear(self):
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x2', shape=[3, 9, 6], dtype="float32")
output = layers.resize_trilinear(x, out_shape=[12, 12, 12])
except ValueError:
pass
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(
name='x', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_trilinear(x, out_shape=[12, 12])
except ValueError:
pass
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_trilinear(x, out_shape=[12, 12, 12])
return (output)
def make_resize_trilinear_by_scale(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_trilinear(x, scale=2.1)
return (output)
def make_polygon_box_transform(self):

@ -176,6 +176,16 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp):
self.align_corners = True
class TestNearestNeighborInterpSame(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 128, 64]
self.out_h = 128
self.out_w = 64
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'

Loading…
Cancel
Save