add Grid Sampler Operator for STN.

fix_recordio_link
dengkaipeng 6 years ago committed by dengkaipeng
parent 79da263b11
commit 0bb0e0c10f

@ -175,6 +175,7 @@ paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dim
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))

@ -0,0 +1,125 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using ScopedSpatialTransformerDescriptor =
platform::ScopedSpatialTransformerDescriptor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output = ctx.Output<Tensor>("Output");
int n = input->dims()[0];
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
const int size[4] = {n, c, h, w};
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, size);
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc, input_data,
grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc, output_data));
}
};
template <typename T>
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
auto output_grad_dims = output_grad->dims();
const int n = output_grad_dims[0];
const int c = output_grad_dims[1];
const int h = output_grad_dims[2];
const int w = output_grad_dims[3];
const int size[4] = {n, c, h, w};
ScopedSpatialTransformerDescriptor st_dest;
cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
st_dest.descriptor<T>(4, size);
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
T* grid_grad_data = grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor input_grad_desc;
ScopedTensorDescriptor output_grad_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
cudnnTensorDescriptor_t cudnn_output_grad_desc = output_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
handle, cudnn_st_dest, CudnnDataType<T>::kOne(),
cudnn_input_desc, input_data, CudnnDataType<T>::kZero(),
cudnn_input_grad_desc, input_grad_data, CudnnDataType<T>::kOne(),
cudnn_output_grad_desc, output_grad_data, grid_data,
CudnnDataType<T>::kZero(), grid_grad_data));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNGridSampleOpKernel<float>,
paddle::operators::CUDNNGridSampleOpKernel<double>);
REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNGridSampleGradOpKernel<float>,
paddle::operators::CUDNNGridSampleGradOpKernel<double>);

@ -0,0 +1,147 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GridSampleOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grid"),
"Input(Grid) of GridSampleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of GridSampleOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
PADDLE_ENFORCE(x_dims.size() == 4, "Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4, "Input(Grid) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], "Input(X) and Input(Grid) dims[0] should be equal.");
PADDLE_ENFORCE_EQ(grid_dims[1], x_dims[2], "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ(grid_dims[2], x_dims[3], "Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
ctx->SetOutputDim("Output", x_dims);
ctx->ShareLoD("X", "Output");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]");
AddInput(
"Grid",
"(Tensor) The output of AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2]");
AddOutput(
"Output",
"(Tensor) Output tensor with shape [N, C, H, W]");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddComment(R"DOC(
It sample input X by grid gennerate by AffineGridOp.
)DOC");
}
};
class GridSampleOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
//TO DO
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("grid_sampler_grad");
op->SetInput("X", Input("X"));
op->SetInput("Grid", Input("Grid"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
ops::GridSampleGradMaker);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
REGISTER_OP_CPU_KERNEL(
grid_sampler,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
grid_sampler_grad,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);

File diff suppressed because it is too large Load Diff

@ -341,6 +341,28 @@ class ScopedPoolingDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor); DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
}; };
class ScopedSpatialTransformerDescriptor {
public:
ScopedSpatialTransformerDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateSpatialTransformerDescriptor(&desc_));
}
~ScopedSpatialTransformerDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroySpatialTransformerDescriptor(desc_));
}
template <typename T>
inline cudnnSpatialTransformerDescriptor_t descriptor(const int nbDims,
const int dimA[]) {
PADDLE_ENFORCE(dynload::cudnnSetSpatialTransformerNdDescriptor(
desc_, CUDNN_SAMPLER_BILINEAR, CudnnDataType<T>::type, nbDims, dimA));
return desc_;
}
private:
cudnnSpatialTransformerDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor);
};
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());

@ -90,6 +90,13 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnSetConvolutionNdDescriptor); \ __macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \ __macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \ __macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor);\
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \ __macro(cudnnCreate); \
__macro(cudnnDestroy); \ __macro(cudnnDestroy); \
__macro(cudnnSetStream); \ __macro(cudnnSetStream); \

@ -157,6 +157,7 @@ __all__ = [
'sequence_reverse', 'sequence_reverse',
'affine_channel', 'affine_channel',
'hash', 'hash',
'grid_sampler',
] ]
@ -7580,3 +7581,38 @@ def hash(input, hash_size, num_hash=1, name=None):
attrs={'num_hash': num_hash, attrs={'num_hash': num_hash,
'mod_by': hash_size}) 'mod_by': hash_size})
return out return out
@templatedoc()
def grid_sampler(x, grid):
"""
It sample data from input x by the given grid, insert data of each
point by bilinear interp.
Args:
x(Variable): Input data of shape [N, H, W, C]
grid(Variable): Input grid tensor of shape [N, H, W, 2]
Returns:
out(Variable): Output data indices by grid from x of shape [N, H, W, C]
"""
helper = LayerHelper("grid_sampler", **locals())
if not isinstance(x, Variable):
return ValueError("The x should be a Variable")
if not isinstance(grid, Variable):
return ValueError("The grid should be a Variable")
out = helper.create_tmp_variable(x.dtype)
ipts = {'X': x, 'Grid': grid}
attrs = {}
helper.apppend_op(
type='grid_sampler',
inputs=ipts,
outputs={'Output', out},
attrs = None if len(attrs) == 0 else attrs)
return 0

@ -0,0 +1,121 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def AffineGrid(theta, size):
n = size[0]
h = size[2]
w = size[3]
h_idx = np.repeat(
np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
w_idx = np.repeat(
np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3
ret = np.zeros([n, h * w, 2])
theta = theta.transpose([0, 2, 1])
for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
# print ret.reshape([n, h * w, 2]).astype("float32")
return ret.reshape([n, h, w, 2]).astype("float32")
def getGridPointValue(data, x, y):
data_shape = data.shape
N = data_shape[0]
H = data_shape[2]
W = data_shape[3]
out = np.zeros(data_shape, dtype='float')
for i in range(N):
for j in range(H):
for k in range(W):
if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[i, j, k] > W - 1:
out[i, :, j, k] = 0
else:
out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
return out
def GridSampler(data, grid):
dims = data.shape
N = dims[0]
C = dims[1]
H = dims[2]
W = dims[3]
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
y_max = H - 1
x_max = W - 1
x = 0.5 * ((x.astype('float32') + 1.0) * x_max)
y = 0.5 * ((y.astype('float32') + 1.0) * y_max)
x0 = np.floor(x).astype('int32')
x1 = x0 + 1
y0 = np.floor(y).astype('int32')
y1 = y0 + 1
wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
va = getGridPointValue(data, x0, y0)
vb = getGridPointValue(data, x0, y1)
vc = getGridPointValue(data, x1, y0)
vd = getGridPointValue(data, x1, y1)
out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32')
return out
class TestGridSamplerOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'grid_sampler'
x = np.random.randint(0, 255, self.x_shape).astype('float32')
theta = np.zeros(self.theta_shape).astype('float32')
for i in range(self.theta_shape[0]):
for j in range(2):
for k in range(3):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid(theta, self.x_shape)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {'use_cudnn': True}
self.outputs = {'Output': GridSampler(x, grid)}
# print self.outputs
def test_check_output(self):
self.check_output(atol=1e-3)
def test_check_grad_normal(self):
self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.6)
def initTestCase(self):
self.x_shape = (2, 5, 7, 3)
self.grid_shape = (2, 7, 3, 2)
self.theta_shape = (2, 2, 3)
if __name__ == "__main__":
unittest.main()

@ -865,6 +865,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out) self.assertIsNotNone(out)
print(str(program)) print(str(program))
def test_affine_grid_gen(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2, 5, 7, 3 ], dtype='float32')
grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32' )
out = layers.grid_sampler(x, grid)
self.assertIsNotNone(out)
print(str(program))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save