Add bilateral_slice op (#25401)

* add bilateral slice op
fix_copy_if_different
LielinJiang 5 years ago committed by GitHub
parent 70554c9f97
commit 7129f544f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,194 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilateral_slice_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
// Forward operator definition for bilateral_slice.
//
// Validates the presence and rank of the inputs X, Grid and Guide, and
// infers the shape of Out as
//   [Grid.dim(0), n_out, Guide.dim(1), Guide.dim(2)]
// where n_out is Grid.dim(1) / (X.dim(1) + 1) when the `has_offset`
// attribute is true and Grid.dim(1) / X.dim(1) otherwise.
class BilateralSliceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Checks the op's inputs/outputs and sets the dimensions of Out.
  //
  // Raises (via PADDLE_ENFORCE):
  //   - Unimplemented   if X is not 4-D.
  //   - InvalidArgument if Grid is not 5-D or Guide is not 3-D.
  //   - InvalidArgument if the grid channel count is not a multiple of the
  //     per-output coefficient count implied by `has_offset`.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BilateralSlice");
    OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "BilateralSlice");
    OP_INOUT_CHECK(ctx->HasInput("Guide"), "Input", "Guide", "BilateralSlice");
    // The slot is named "Out"; report that name so the message matches the
    // op's proto.
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BilateralSlice");

    auto input_dims = ctx->GetInputDim("X");  // NCHW format
    PADDLE_ENFORCE_EQ(
        input_dims.size(), 4,
        platform::errors::Unimplemented(
            "Input(X) dimension must be 4, but got dimension = %d .",
            input_dims.size()));

    // Grid and Guide are indexed below, so their ranks must be validated
    // first; otherwise a malformed input would be read out of range.
    auto grid_dims = ctx->GetInputDim("Grid");
    PADDLE_ENFORCE_EQ(
        grid_dims.size(), 5,
        platform::errors::InvalidArgument(
            "Input(Grid) dimension must be 5, but got dimension = %d .",
            grid_dims.size()));
    auto guide_dims = ctx->GetInputDim("Guide");
    PADDLE_ENFORCE_EQ(
        guide_dims.size(), 3,
        platform::errors::InvalidArgument(
            "Input(Guide) dimension must be 3, but got dimension = %d .",
            guide_dims.size()));

    bool has_offset = ctx->Attrs().Get<bool>("has_offset");

    // Output spatial size follows the guide map; batch size follows the grid.
    int64_t h = guide_dims[1];
    int64_t w = guide_dims[2];
    int64_t bs = grid_dims[0];
    int64_t coeffs_chans = grid_dims[1];
    int64_t input_chans = input_dims[1];

    int64_t output_chans;
    if (has_offset) {
      // Each output channel owns input_chans weights plus one offset term.
      PADDLE_ENFORCE_EQ((coeffs_chans % (input_chans + 1)), 0,
                        platform::errors::InvalidArgument(
                            "Slicing with affine offset, coefficients grid "
                            "should have n_out*(n_in+1) channels, but got %d",
                            coeffs_chans));
      output_chans = coeffs_chans / (input_chans + 1);
    } else {
      // Each output channel owns exactly input_chans weights.
      PADDLE_ENFORCE_EQ((coeffs_chans % input_chans), 0,
                        platform::errors::InvalidArgument(
                            "Slicing without affine offset, coefficients grid "
                            "should have n_out*n_in channels, but got %d .",
                            coeffs_chans));
      output_chans = coeffs_chans / input_chans;
    }

    std::vector<int64_t> output_dims = {bs, output_chans, h, w};
    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
  }

  // The kernel data type is taken from input X; the place from the context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }

  // Keeps each input tensor's own place and layout while using the expected
  // kernel's data type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
// Declares the proto of the bilateral_slice operator: its three inputs
// (X, Grid, Guide), its single output (Out), the `has_offset` attribute and
// the operator-level documentation string.
class BilateralSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Registers the slots in declaration order. The description strings below
// are emitted as the op's documentation, so they are runtime text rather
// than comments.
void Make() override {
AddInput("X",
"The input tensor of bilateral_slice operator, "
"This is a 4-D tensor with shape of [N, C, H, W]");
AddInput("Grid",
"This is a 5-D tensor. "
"It should be [N, C, D, H, W].");
AddInput("Guide",
"This is a 3-D tensor "
"It should be [N, H, W].");
AddOutput("Out",
"The output tensor of bilateral slice operator, "
"This is a tensor in same rank with Input(X).");
// has_offset defaults to false. BilateralSliceOp::InferShape reads it to
// decide whether Grid must carry n_out*(n_in+1) or n_out*n_in channels.
AddAttr<bool>("has_offset", "an optional bool. Defaults to False. ")
.SetDefault(false);
AddComment(R"DOC(
This operator enhance input X according guide and grid
For details of bilateral slice, please refer to paper:
https://groups.csail.mit.edu/graphics/hdrnet/
)DOC");
}
};
// Gradient operator definition for bilateral_slice.
//
// Requires the forward inputs X, Grid and Guide plus the gradient of Out,
// and gives every requested input-gradient the shape of its forward input.
class BilateralSliceOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Verifies the required inputs exist, then mirrors the dimensions of each
  // forward input onto the matching gradient output when that output is
  // present.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BilateralSliceOpGrad");
    OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid",
                   "BilateralSliceOpGrad");
    OP_INOUT_CHECK(ctx->HasInput("Guide"), "Input", "Guide",
                   "BilateralSliceOpGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", "Out",
                   "BilateralSliceOpGrad");

    for (const char* forward_name : {"X", "Grid", "Guide"}) {
      // The forward dims are always queried, whether or not the gradient
      // output was requested.
      const auto forward_dims = ctx->GetInputDim(forward_name);
      const auto grad_name = framework::GradVarName(forward_name);
      if (ctx->HasOutput(grad_name)) {
        ctx->SetOutputDim(grad_name, forward_dims);
      }
    }
  }

  // The kernel data type follows the incoming gradient of Out.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto grad_dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(grad_dtype, ctx.GetPlace());
  }
};
// Builds the description of the bilateral_slice gradient op from the forward
// op: forwards X, Grid and Guide unchanged, feeds in the gradient of Out, and
// exposes one gradient output per forward input. Forward attributes are
// copied as-is.
template <typename T>
class BilateralSliceGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");

    // Forward inputs are passed through under their original slot names.
    for (const char* slot : {"X", "Grid", "Guide"}) {
      op->SetInput(slot, this->Input(slot));
    }
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

    // One gradient output for each forward input, in the same order.
    for (const char* slot : {"X", "Grid", "Guide"}) {
      op->SetOutput(framework::GradVarName(slot), this->InputGrad(slot));
    }

    op->SetAttrMap(this->Attrs());
  }
};
// Placeholder kernel for bilateral_slice. It performs no computation; it
// only raises an Unimplemented error when the execution place is not a GPU.
template <typename T>
class BilateralSliceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const bool running_on_gpu = platform::is_gpu_place(ctx.GetPlace());
    PADDLE_ENFORCE_EQ(running_on_gpu, true,
                      platform::errors::Unimplemented(
                          "BilateralSlice only supports GPU now."));
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the forward op with its proto maker and with gradient makers for
// both OpDesc and imperative OpBase.
REGISTER_OPERATOR(bilateral_slice, ops::BilateralSliceOp,
ops::BilateralSliceOpMaker,
ops::BilateralSliceGradMaker<paddle::framework::OpDesc>,
ops::BilateralSliceGradMaker<paddle::imperative::OpBase>);
// Register the gradient op (shape inference and kernel-type selection only).
REGISTER_OPERATOR(bilateral_slice_grad, ops::BilateralSliceOpGrad);
// CPU kernels for float and double. BilateralSliceKernel::Compute does no
// work and raises Unimplemented unless the place is a GPU.
REGISTER_OP_CPU_KERNEL(bilateral_slice, ops::BilateralSliceKernel<float>,
ops::BilateralSliceKernel<double>);

File diff suppressed because it is too large Load Diff

@ -0,0 +1,33 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
// Plain bundle of the sizes involved in a bilateral slice.
// NOTE(review): the field names mirror locals in
// BilateralSliceOp::InferShape (h/w from Guide dims, bs/coeffs_chans from
// Grid dims, input_chans from X dims); presumably they carry the same
// meaning here - confirm against the code that fills this struct.
struct GridSizes {
int64_t h;             // presumably the output/guide height - TODO confirm
int64_t w;             // presumably the output/guide width - TODO confirm
int64_t bs;            // presumably the batch size - TODO confirm
int64_t coeffs_chans;  // presumably the grid's channel count - TODO confirm
int64_t gd;            // presumably the grid depth - TODO confirm
int64_t gh;            // presumably the grid height - TODO confirm
int64_t gw;            // presumably the grid width - TODO confirm
int64_t input_chans;   // presumably the channel count of X - TODO confirm
};
} // namespace operators
} // namespace paddle

@ -31,9 +31,10 @@ struct GpuLaunchConfig {
};
inline GpuLaunchConfig getGpuLaunchConfig(
const int N, const framework::ExecutionContext& ctx) {
const int N, const framework::ExecutionContext& ctx,
int max_threads = 1024) {
int threads =
std::min(1024, ctx.cuda_device_context().GetMaxThreadsPerBlock());
std::min(max_threads, ctx.cuda_device_context().GetMaxThreadsPerBlock());
int physical_thread_count =
std::min(ctx.cuda_device_context().GetMaxPhysicalThreadCount(), N);
int blocks = std::min((physical_thread_count + threads - 1) / threads,

@ -35,7 +35,7 @@ __all__ = [
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc',
'_pull_box_extended_sparse'
'_pull_box_extended_sparse', 'bilateral_slice'
]
@ -1409,3 +1409,65 @@ def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'):
if len(outs) == 1:
return outs[0], outs_extend[0]
return outs, outs_extend
def bilateral_slice(x, guide, grid, has_offset=False, name=None):
    """
    :alias_main: paddle.nn.functional.bilateral_slice
    :alias: paddle.nn.functional.bilateral_slice,paddle.nn.functional.vision.bilateral_slice
    :old_api: paddle.fluid.layers.bilateral_slice

    This operation implements bilateral slicing on the input according to
    the guide map. For more information on bilateral slicing, please refer to
    `Deep Bilateral Learning for Real-Time Image Enhancement
    <https://groups.csail.mit.edu/graphics/hdrnet/data/hdrnet.pdf>`_ .

    Args:
        x(Variable): The input tensor, which is a 4-D tensor with shape
            [N, C, H, W], N is the batch size, C is the channel
            number, H and W are the feature height and width.
            The data type is float32 or float64.
        guide(Variable): Input guide tensor of shape [N, H, W]. The
            data type is float32 or float64.
        grid(Variable): Input grid tensor of shape [N, C, D, H, W]. The
            data type is float32 or float64. Its channel number must be a
            multiple of ``C_x + 1`` when ``has_offset`` is True and of
            ``C_x`` otherwise, where ``C_x`` is the channel number of ``x``.
        has_offset(bool, optional): Whether to slice with affine offset.
            Default: False, matching the default of the operator attribute.
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name is no need to set and
            None by default.

    Returns:
        Variable: Output of shape [N, C_out, H, W], where H and W are taken
        from ``guide`` and ``C_out`` is the channel number of ``grid``
        divided by ``C_x + 1`` (with offset) or by ``C_x`` (without offset).
        The data type is same as input tensor.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            x = fluid.data(name='x', shape=[None, 3, 101, 60], dtype='float32')
            guide = fluid.data(name='guide', shape=[None, 101, 60], dtype='float32')
            grid = fluid.data(name='grid', shape=[None, 12, 8, 10, 6], dtype='float32')

            # without offset
            output = fluid.contrib.layers.bilateral_slice(x, guide, grid, has_offset=False)

            # has offset
            output = fluid.contrib.layers.bilateral_slice(x, guide, grid, has_offset=True)
    """
    # locals() must be captured before any other local is introduced so that
    # only the call arguments are forwarded to the helper.
    helper = LayerHelper("bilateral_slice", **locals())

    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'bilateral_slice')
    check_variable_and_dtype(guide, 'guide', ['float32', 'float64'],
                             'bilateral_slice')
    check_variable_and_dtype(grid, 'grid', ['float32', 'float64'],
                             'bilateral_slice')

    out = helper.create_variable_for_type_inference(x.dtype)
    inputs = {'X': x, 'Guide': guide, 'Grid': grid}

    helper.append_op(
        type='bilateral_slice',
        inputs=inputs,
        attrs={'has_offset': has_offset},
        outputs={'Out': out})
    return out

@ -0,0 +1,194 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import math
class Gsz:
    """Holds the sizes used by the reference slice: output height/width,
    grid depth/height/width and the number of input channels."""

    def __init__(self, h, w, gd, gh, gw, input_chans):
        self.h, self.w = h, w
        self.gd, self.gh, self.gw = gd, gh, gw
        self.input_chans = input_chans
def diff_abs(x):
    """Smoothed absolute value: sqrt(x * x + 1e-8)."""
    smoothing = 1e-8
    squared = x * x
    return math.sqrt(squared + smoothing)
def d_diff_abs(x):
    """Derivative of the smoothed absolute value: x / sqrt(x * x + 1e-8)."""
    smoothing = 1e-8
    norm = math.sqrt(x * x + smoothing)
    return x / norm
def weight_z(x):
    """Linear falloff weight: 1 - diff_abs(x), floored at 0."""
    falloff = 1.0 - diff_abs(x)
    # Same selection rule as max(falloff, 0.0): 0.0 only wins when it is
    # strictly greater than the falloff.
    return 0.0 if falloff < 0.0 else falloff
def d_weight_z(x):
    """Returns 0.0 when diff_abs(x) exceeds 1.0, otherwise d_diff_abs(x)."""
    if diff_abs(x) > 1.0:
        return 0.0
    return d_diff_abs(x)
def naive_bilateral_slice_forward(output, grid, guide, input, gsz, has_offset,
                                  total_count, output_chans):
    """Reference (loop-based) bilateral slice, written into ``output``.

    For each of the ``total_count`` flat output indices, the index is
    decomposed into (batch, output channel, row, column). The grid is sampled
    at the 2x2x2 neighbourhood around the pixel's grid position (x/y from the
    pixel location, z from the guide value) and the sampled coefficients are
    applied to the input channels, plus one offset coefficient per output
    channel when ``has_offset`` is set.
    """
    h, w = gsz.h, gsz.w
    gd, gh, gw = gsz.gd, gsz.gh, gsz.gw
    input_chans = gsz.input_chans
    # With an affine offset every output channel owns one extra coefficient.
    coeff_stride = input_chans + 1 if has_offset else input_chans

    def clamp(index, size):
        # Clamp a neighbour index into [0, size - 1].
        return max(min(index, size - 1), 0)

    for idx in range(total_count):
        # x varies fastest, then y, then output channel, then batch.
        rest, x = divmod(idx, w)
        rest, y = divmod(rest, h)
        b, out_c = divmod(rest, output_chans)

        gx = (x + 0.5) * gw / (1.0 * w)
        gy = (y + 0.5) * gh / (1.0 * h)
        gz = guide[int(b), int(y), int(x)] * gd

        fx = int(np.floor(gx - 0.5))
        fy = int(np.floor(gy - 0.5))
        fz = int(np.floor(gz - 0.5))

        # (clamped index, weight) of the two neighbours along each axis.
        # They do not depend on the coefficient channel, so they are built
        # once per pixel.
        x_taps = [(clamp(xx, gw), max(1.0 - abs(xx + 0.5 - gx), 0.0))
                  for xx in range(fx, fx + 2)]
        y_taps = [(clamp(yy, gh), max(1.0 - abs(yy + 0.5 - gy), 0.0))
                  for yy in range(fy, fy + 2)]
        z_taps = [(clamp(zz, gd), weight_z(zz + 0.5 - gz))
                  for zz in range(fz, fz + 2)]

        value = 0.0
        for in_c in range(0, coeff_stride):
            c_ = coeff_stride * out_c + in_c
            coeff_sample = 0.0
            # Accumulation order (x, then y, then z) is kept so the floating
            # point result is unchanged.
            for x_, wx in x_taps:
                for y_, wy in y_taps:
                    for z_, wz in z_taps:
                        coeff_sample += grid[int(b), int(c_), int(z_),
                                             int(y_), int(x_)] * wx * wy * wz
            if in_c < input_chans:
                value += coeff_sample * input[int(b), int(in_c), int(y),
                                              int(x)]
            else:
                # The trailing coefficient is the affine offset.
                value += coeff_sample

        output[int(b), int(out_c), int(y), int(x)] = value
def naive_bilateral_slice(x, guide, grid, has_offset):
    """Allocates the output for the reference bilateral slice and fills it.

    The output has shape [batch, output_chans, h, w] and the dtype of ``x``;
    ``output_chans`` is the grid's channel count divided by
    ``input_chans + 1`` (with offset) or ``input_chans`` (without offset).
    """
    bs, input_chans, h, w = (x.shape[axis] for axis in range(4))
    coeffs_chans = grid.shape[1]

    coeffs_per_output = input_chans + 1 if has_offset else input_chans
    output_chans = coeffs_chans // coeffs_per_output

    output = np.zeros([bs, int(output_chans), h, w]).astype(x.dtype)

    gd, gh, gw = (grid.shape[axis] for axis in (2, 3, 4))
    sizes = Gsz(h, w, gd, gh, gw, input_chans)

    n_out = output.shape[1]
    total_count = bs * h * w * n_out
    naive_bilateral_slice_forward(output, grid, guide, x, sizes, has_offset,
                                  total_count, n_out)
    return output
@unittest.skipIf(not paddle.fluid.is_compiled_with_cuda(),
                 'CPU testing is not supported')
class TestBilateralSliceOp(OpTest):
    """Checks the bilateral_slice op on CUDAPlace(0) against the NumPy
    reference implementation (no affine offset, float64)."""

    def setUp(self):
        # Sets has_offset / data_type; subclasses override initTestCase.
        self.initTestCase()
        self.op_type = 'bilateral_slice'
        batch_size = 3
        h = 50
        w = 30
        c = 1
        gh = 5
        gw = 3
        gd = 2
        gc = 2
        x = np.random.rand(batch_size, c, h, w).astype(self.data_type)
        guide = np.random.rand(batch_size, h, w).astype(self.data_type)
        grid = np.random.rand(batch_size, gc, gd, gh, gw).astype(self.data_type)
        output_np = naive_bilateral_slice(x, guide, grid, self.has_offset)

        self.inputs = {'X': x, 'Grid': grid, 'Guide': guide}
        self.attrs = {'has_offset': self.has_offset, }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        # A stray bare `self.check_output` expression used to follow this
        # call; it only looked the attribute up and never invoked it, so it
        # has been dropped.
        place = paddle.fluid.CUDAPlace(0)
        self.check_output_with_place(place, atol=1e-5)

    def test_check_grad(self):
        place = paddle.fluid.CUDAPlace(0)
        self.check_grad_with_place(place, ['X'], 'Out')

    def initTestCase(self):
        self.has_offset = False
        self.data_type = 'float64'
@unittest.skipIf(not paddle.fluid.is_compiled_with_cuda(),
                 'CPU testing is not supported')
class TestBilateralSliceOp1(TestBilateralSliceOp):
    """Same checks as the base case, with affine offset and float32 data."""

    def initTestCase(self):
        self.has_offset, self.data_type = True, 'float32'
class TestBilateralSliceApi(TestBilateralSliceOp):
    """Builds the layer through the Python API using the base case's
    has_offset setting."""

    def test_api(self):
        fluid = paddle.fluid
        x = fluid.data(name='x', shape=[None, 3, 25, 15], dtype='float32')
        guide = fluid.data(name='guide', shape=[None, 25, 15], dtype='float32')
        grid = fluid.data(
            name='grid', shape=[None, 12, 8, 5, 3], dtype='float32')
        fluid.contrib.layers.bilateral_slice(x, guide, grid, self.has_offset)
# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()

@ -74,7 +74,8 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'transpose2', \
'trilinear_interp', \
'var_conv_2d', \
'warpctc'
'warpctc', \
'bilateral_slice'
]
NO_FP16_CHECK_GRAD_OP_LIST = [

@ -40,7 +40,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'teacher_student_sigmoid_loss', \
'unpool', \
'yolov3_loss', \
'inverse'
'inverse', \
'bilateral_slice'
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']

Loading…
Cancel
Save