【2.0 API】Enhance affine grid operator (#26385)
* Enhance affine grid operator: 1. Add cuda kernel 2. Add align corners options test=develop * Move new affine_grid api to functional test=develop * Add CUDA kernel for affine_grid. test=develop * Add more unitest for grid sample API test=developrevert-26856-strategy_example2
parent
6f69fbc8ea
commit
a065a24232
@ -0,0 +1,209 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/affine_grid_op.h"
|
||||
#include "paddle/fluid/platform/cuda_device_function.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
|
||||
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Linspace<paddle::platform::CUDADeviceContext, T> {
|
||||
void operator()(T start, T end, int count, bool align_corners,
|
||||
framework::Tensor* numbers,
|
||||
const framework::ExecutionContext& ctx) {
|
||||
T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
|
||||
T slice = (end - start) / (T)(count - 1);
|
||||
if (!align_corners) {
|
||||
slice = (end - start) / (T)count;
|
||||
start *= (T)(count - 1) / (T)count;
|
||||
}
|
||||
auto stream = ctx.cuda_device_context().stream();
|
||||
int block = 512;
|
||||
int grid = (count + block - 1) / block;
|
||||
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count,
|
||||
number_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w,
|
||||
T h_start, T w_start, T h_step, T w_step,
|
||||
const T* theta, // N, 2, 3
|
||||
T* output) {
|
||||
CUDA_KERNEL_LOOP(index, count) {
|
||||
int w = index % out_w;
|
||||
int h = (index / out_w) % out_h;
|
||||
int n = index / (out_w * out_h);
|
||||
|
||||
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
|
||||
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
|
||||
|
||||
int theta_offset = n * 6; // 2 * 3;
|
||||
// affine from (h_coor, w_coor) to (x, y)
|
||||
output[index * 2] = theta[theta_offset] * h_coor +
|
||||
theta[theta_offset + 1] * w_coor +
|
||||
theta[theta_offset + 2];
|
||||
output[index * 2 + 1] = theta[theta_offset + 3] * h_coor +
|
||||
theta[theta_offset + 4] * w_coor +
|
||||
theta[theta_offset + 5];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
|
||||
int out_w, T h_start, T w_start,
|
||||
T h_step, T w_step,
|
||||
const T* out_grad, // N, H, W, 2
|
||||
T* theta_grad) { // N, 2, 3
|
||||
CUDA_KERNEL_LOOP(index, count) {
|
||||
int w = index % out_w;
|
||||
int h = (index / out_w) % out_h;
|
||||
int n = index / (out_w * out_h);
|
||||
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
|
||||
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
|
||||
|
||||
int theta_offset = n * 6; // 2 * 3;
|
||||
T out_grad_x = out_grad[index * 2];
|
||||
atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
|
||||
atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
|
||||
atomicAdd(theta_grad + theta_offset + 2, out_grad_x);
|
||||
|
||||
T out_grad_y = out_grad[index * 2 + 1];
|
||||
atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
|
||||
atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
|
||||
atomicAdd(theta_grad + theta_offset + 5, out_grad_y);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* theta = ctx.Input<Tensor>("Theta");
|
||||
int n = theta->dims()[0];
|
||||
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
|
||||
auto align_corners = ctx.Attr<bool>("align_corners");
|
||||
int h = 0;
|
||||
int w = 0;
|
||||
if (size_attr.size() == 0) {
|
||||
auto* output_shape = ctx.Input<Tensor>("OutputShape");
|
||||
Tensor h_sizes;
|
||||
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
|
||||
const int* h_size_data = h_sizes.data<int>();
|
||||
h = h_size_data[2];
|
||||
w = h_size_data[3];
|
||||
} else {
|
||||
h = size_attr[2];
|
||||
w = size_attr[3];
|
||||
}
|
||||
auto* output = ctx.Output<Tensor>("Output");
|
||||
T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
|
||||
|
||||
T h_step;
|
||||
T w_step;
|
||||
T h_start = -1;
|
||||
T w_start = -1;
|
||||
if (align_corners) {
|
||||
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
|
||||
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
|
||||
} else {
|
||||
h_step = static_cast<T>(2) / static_cast<T>(h);
|
||||
w_step = static_cast<T>(2) / static_cast<T>(w);
|
||||
|
||||
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
|
||||
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
|
||||
}
|
||||
|
||||
const int count = n * h * w;
|
||||
int block = 512;
|
||||
int grid = (count + block - 1) / block;
|
||||
auto cu_stream = ctx.cuda_device_context().stream();
|
||||
affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
|
||||
count, n, h, w, h_start, w_start, h_step, w_step,
|
||||
theta->data<T>(), // N, 2, 3
|
||||
out_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
|
||||
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
|
||||
int n = output_grad->dims()[0];
|
||||
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
|
||||
auto align_corners = ctx.Attr<bool>("align_corners");
|
||||
int h = 0;
|
||||
int w = 0;
|
||||
if (size_attr.size() == 0) {
|
||||
auto* output_shape = ctx.Input<Tensor>("OutputShape");
|
||||
Tensor h_sizes;
|
||||
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
|
||||
const int* h_size_data = h_sizes.data<int>();
|
||||
h = h_size_data[2];
|
||||
w = h_size_data[3];
|
||||
} else {
|
||||
h = size_attr[2];
|
||||
w = size_attr[3];
|
||||
}
|
||||
T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
|
||||
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
|
||||
ctx.cuda_device_context(), theta_grad, static_cast<T>(0));
|
||||
|
||||
T h_step;
|
||||
T w_step;
|
||||
T h_start = -1;
|
||||
T w_start = -1;
|
||||
if (align_corners) {
|
||||
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
|
||||
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
|
||||
} else {
|
||||
h_step = static_cast<T>(2) / static_cast<T>(h);
|
||||
w_step = static_cast<T>(2) / static_cast<T>(w);
|
||||
|
||||
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
|
||||
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
|
||||
}
|
||||
const int count = n * h * w;
|
||||
VLOG(3) << "count: " << count << "; h_step: " << h_step
|
||||
<< "; w_step: " << w_step << "; h_start: " << h_start
|
||||
<< "; w_start: " << w_start;
|
||||
int block = 512;
|
||||
int grid = (count + block - 1) / block;
|
||||
auto cu_stream = ctx.cuda_device_context().stream();
|
||||
affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
|
||||
count, n, h, w, h_start, w_start, h_step, w_step,
|
||||
output_grad->data<T>(), theta_grad_data);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>,
|
||||
ops::AffineGridOpCUDAKernel<double>);
|
||||
REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
|
||||
ops::AffineGridGradOpCUDAKernel<float>,
|
||||
ops::AffineGridGradOpCUDAKernel<double>);
|
@ -0,0 +1,149 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from paddle import fluid, nn
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.nn.functional as F
|
||||
import paddle.fluid.initializer as I
|
||||
import unittest
|
||||
|
||||
|
||||
class AffineGridTestCase(unittest.TestCase):
|
||||
def __init__(self,
|
||||
methodName='runTest',
|
||||
theta_shape=(20, 2, 3),
|
||||
output_shape=[20, 2, 5, 7],
|
||||
align_corners=True,
|
||||
dtype="float32",
|
||||
invalid_theta=False,
|
||||
variable_output_shape=False):
|
||||
super(AffineGridTestCase, self).__init__(methodName)
|
||||
|
||||
self.theta_shape = theta_shape
|
||||
self.output_shape = output_shape
|
||||
self.align_corners = align_corners
|
||||
self.dtype = dtype
|
||||
self.invalid_theta = invalid_theta
|
||||
self.variable_output_shape = variable_output_shape
|
||||
|
||||
def setUp(self):
|
||||
self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype)
|
||||
|
||||
def fluid_layer(self, place):
|
||||
# align_corners = True
|
||||
main = fluid.Program()
|
||||
start = fluid.Program()
|
||||
with fluid.unique_name.guard():
|
||||
with fluid.program_guard(main, start):
|
||||
theta_var = fluid.data(
|
||||
"input", self.theta_shape, dtype=self.dtype)
|
||||
y_var = fluid.layers.affine_grid(theta_var, self.output_shape)
|
||||
feed_dict = {"input": self.theta}
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(start)
|
||||
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
|
||||
return y_np
|
||||
|
||||
def functional(self, place):
|
||||
main = fluid.Program()
|
||||
start = fluid.Program()
|
||||
with fluid.unique_name.guard():
|
||||
with fluid.program_guard(main, start):
|
||||
theta_var = fluid.data(
|
||||
"input", self.theta_shape, dtype=self.dtype)
|
||||
y_var = F.affine_grid(
|
||||
theta_var,
|
||||
self.output_shape,
|
||||
align_corners=self.align_corners)
|
||||
feed_dict = {"input": self.theta}
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(start)
|
||||
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
|
||||
return y_np
|
||||
|
||||
def paddle_dygraph_layer(self):
|
||||
theta_var = dg.to_variable(
|
||||
self.theta) if not self.invalid_theta else "invalid"
|
||||
output_shape = dg.to_variable(
|
||||
self.
|
||||
output_shape) if self.variable_output_shape else self.output_shape
|
||||
y_var = F.affine_grid(
|
||||
theta_var, output_shape, align_corners=self.align_corners)
|
||||
y_np = y_var.numpy()
|
||||
return y_np
|
||||
|
||||
def _test_equivalence(self, place):
|
||||
place = fluid.CPUPlace()
|
||||
result1 = self.fluid_layer(place)
|
||||
result2 = self.functional(place)
|
||||
with dg.guard(place):
|
||||
result3 = self.paddle_dygraph_layer()
|
||||
if self.align_corners:
|
||||
np.testing.assert_array_almost_equal(result1, result2)
|
||||
np.testing.assert_array_almost_equal(result2, result3)
|
||||
|
||||
def runTest(self):
|
||||
place = fluid.CPUPlace()
|
||||
self._test_equivalence(place)
|
||||
|
||||
if fluid.core.is_compiled_with_cuda():
|
||||
place = fluid.CUDAPlace(0)
|
||||
self._test_equivalence(place)
|
||||
|
||||
|
||||
class AffineGridErrorTestCase(AffineGridTestCase):
|
||||
def runTest(self):
|
||||
place = fluid.CPUPlace()
|
||||
with dg.guard(place):
|
||||
with self.assertRaises(ValueError):
|
||||
self.paddle_dygraph_layer()
|
||||
|
||||
|
||||
def add_cases(suite):
|
||||
suite.addTest(AffineGridTestCase(methodName='runTest'))
|
||||
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True))
|
||||
|
||||
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False))
|
||||
suite.addTest(
|
||||
AffineGridTestCase(
|
||||
methodName='runTest', variable_output_shape=True))
|
||||
|
||||
suite.addTest(
|
||||
AffineGridTestCase(
|
||||
methodName='runTest',
|
||||
theta_shape=(20, 2, 3),
|
||||
output_shape=[20, 1, 7, 7],
|
||||
align_corners=True))
|
||||
|
||||
|
||||
def add_error_cases(suite):
|
||||
suite.addTest(
|
||||
AffineGridErrorTestCase(
|
||||
methodName='runTest', output_shape="not_valid"))
|
||||
suite.addTest(
|
||||
AffineGridErrorTestCase(
|
||||
methodName='runTest',
|
||||
invalid_theta=True)) # to test theta not variable error checking
|
||||
|
||||
|
||||
def load_tests(loader, standard_tests, pattern):
|
||||
suite = unittest.TestSuite()
|
||||
add_cases(suite)
|
||||
add_error_cases(suite)
|
||||
return suite
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue