Optimization of kernels related to DeepLabv3+ (#13534)

* refine reduce ops using cub
* optimize KernelDepthwiseConvFilterGrad
* optimize depthwise conv, reduce_mean, and reduce_sum
* fix bug: dilation handling in depthwise conv
* keep cuda arch and cuda 8 compatibility
Authored by Dun, committed by qingqing01
parent 35b713c3fd
commit 161c3e31f7

@@ -301,6 +301,7 @@ op_library(fusion_lstm_op DEPS cpu_lstm_compute)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
op_library(layer_norm_op DEPS cub)
op_library(reduce_mean_op DEPS cub)
else()
op_library(conv_op DEPS vol2col im2col)
endif()

@@ -380,7 +380,8 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
auto& dev_ctx = context.template device_context<DeviceContext>();
depthwiseConv(dev_ctx, *input, filter, strides, paddings, output);
depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
output);
}
};
@@ -415,14 +416,14 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, input_grad, static_cast<T>(0));
depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
paddings, input_grad);
paddings, dilations, input_grad);
}
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, paddings,
filter_grad);
dilations, filter_grad);
}
}
};
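For reference only (not part of this commit): a minimal CPU sketch of where the newly threaded-through dilations argument enters the depthwise-conv index arithmetic. The helper name and the single-plane layout are assumptions for illustration.

#include <vector>

// Hypothetical reference helper: depthwise 2-D conv for a single
// (batch, channel) plane, showing how dilation stretches the receptive field.
void DepthwiseConvPlaneRef(const std::vector<float>& in, int in_h, int in_w,
                           const std::vector<float>& filt, int k_h, int k_w,
                           int stride_h, int stride_w, int pad_h, int pad_w,
                           int dil_h, int dil_w,
                           std::vector<float>* out, int out_h, int out_w) {
  for (int oh = 0; oh < out_h; ++oh) {
    for (int ow = 0; ow < out_w; ++ow) {
      float val = 0.f;
      for (int kh = 0; kh < k_h; ++kh) {
        for (int kw = 0; kw < k_w; ++kw) {
          // Each filter tap is dil_h / dil_w input pixels apart.
          int ih = oh * stride_h - pad_h + kh * dil_h;
          int iw = ow * stride_w - pad_w + kw * dil_w;
          if (ih >= 0 && ih < in_h && iw >= 0 && iw < in_w) {
            val += in[ih * in_w + iw] * filt[kh * k_w + kw];
          }
        }
      }
      (*out)[oh * out_w + ow] = val;
    }
  }
}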

@@ -345,7 +345,7 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
math::DepthwiseConvInputGradFunctor<DeviceContext, T>
depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings,
output);
dilations, output);
}
};
@@ -367,10 +367,11 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
if (input_grad) {
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings,
depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings, dilations,
input_grad);
}
@@ -382,7 +383,7 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings,
filter_grad);
dilations, filter_grad);
}
}
};

(File diff suppressed because it is too large.)

(File diff suppressed because it is too large.)

@@ -32,7 +32,8 @@ class DepthwiseConvFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output);
const std::vector<int>& paddings,
const std::vector<int>& dilations, framework::Tensor* output);
};
template <typename DeviceContext, typename T>
@@ -43,6 +44,7 @@ class DepthwiseConvInputGradFunctor {
const framework::Tensor& output_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* input_grad);
};
@@ -53,6 +55,7 @@ class DepthwiseConvFilterGradFunctor {
const framework::Tensor& output_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* filter_grad);
};
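As a side note (an illustration, not code from this diff): with the dilations argument, the output spatial extent these functors produce follows the usual effective-kernel-size formula, sketched below.

// Illustration only: output extent of a dilated convolution along one axis.
inline int ConvOutputSize(int input, int k, int pad, int stride, int dilation) {
  int effective_k = dilation * (k - 1) + 1;  // dilated kernel footprint
  return (input + 2 * pad - effective_k) / stride + 1;
}
// Example: input=5, k=3, pad=1, stride=1, dilation=2 -> (5 + 2 - 5) / 1 + 1 = 3.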

@@ -12,17 +12,64 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/fluid/operators/cub_reduce.h"
#include "paddle/fluid/operators/reduce_mean_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MeanFunctor>);
namespace paddle {
namespace operators {
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
private:
T n_inv;
};
template <typename T>
class ReduceMeanKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool reduce_all = context.Attr<bool>("reduce_all");
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
auto dims = context.Attr<std::vector<int>>("dim");
bool keep_dim = context.Attr<bool>("keep_dim");
std::vector<int> reduce_dims;
if (reduce_all) {
reduce_dims.resize(input->dims().size());
for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
} else {
for (auto e : dims) {
reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
}
}
int reduce_num = 1;
for (int i = 0; i < reduce_dims.size(); ++i) {
reduce_num *= input->dims()[reduce_dims[i]];
}
auto stream = context.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
*input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
DivideFunctor<T>(reduce_num), stream);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>,
ops::ReduceMeanKernel<double>,
ops::ReduceMeanKernel<int>,
ops::ReduceMeanKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(
reduce_mean_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanGradFunctor>,
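Illustration (not part of the commit): DivideFunctor pre-scales every element by 1/n, so the mean falls out of an ordinary sum. A host-side C++ equivalent of that idea:

#include <vector>

// Host-side sketch only; TensorReduce/cub perform the same computation on the GPU.
float MeanViaTransformedSum(const std::vector<float>& x) {
  const float n_inv = 1.0f / static_cast<float>(x.size());
  float sum = 0.0f;
  for (float v : x) sum += v * n_inv;  // apply DivideFunctor, then accumulate
  return sum;
}
// MeanViaTransformedSum({1.f, 2.f, 3.f, 4.f}) returns 2.5f.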

@@ -12,17 +12,59 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cub_reduce.h"
#include "paddle/fluid/operators/reduce_sum_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_sum,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::SumFunctor>);
namespace paddle {
namespace operators {
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
template <typename T>
class ReduceSumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool reduce_all = context.Attr<bool>("reduce_all");
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
auto dims = context.Attr<std::vector<int>>("dim");
bool keep_dim = context.Attr<bool>("keep_dim");
std::vector<int> reduce_dims;
if (reduce_all) {
reduce_dims.resize(input->dims().size());
for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
} else {
for (auto e : dims) {
reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
}
}
int reduce_num = 1;
for (int i = 0; i < reduce_dims.size(); ++i) {
reduce_num *= input->dims()[reduce_dims[i]];
}
auto stream = context.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
*input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>,
ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
ops::ReduceSumKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::SumGradFunctor>,
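Taken together, the only difference between the new reduce_sum and reduce_mean kernels is the element-wise transform handed to the shared reduction. A hedged host-side sketch of that pattern (names hypothetical, not the cub path):

#include <vector>

// Generic "transform then reduce" loop; the GPU implementation uses cub instead.
template <typename T, typename TransformOp>
T TransformReduce(const std::vector<T>& x, T init, TransformOp transform) {
  T acc = init;
  for (const T& v : x) acc += transform(v);
  return acc;
}

struct Identity {   // plays the role of IdentityFunctor (reduce_sum)
  float operator()(float v) const { return v; }
};
struct Scale {      // plays the role of DivideFunctor (reduce_mean)
  float factor;
  float operator()(float v) const { return v * factor; }
};

// float sum  = TransformReduce(x, 0.0f, Identity{});
// float mean = TransformReduce(x, 0.0f, Scale{1.0f / x.size()});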

@@ -67,6 +67,7 @@ class TestConv2dOp(OpTest):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
self.use_cuda = False
self.use_mkldnn = False
self.data_format = "AnyLayout"
self.dtype = np.float32
@@ -101,24 +102,25 @@ class TestConv2dOp(OpTest):
}
self.outputs = {'Output': output}
def testcudnn(self):
return core.is_compiled_with_cuda() and self.use_cudnn
def testcuda(self):
return core.is_compiled_with_cuda() and (self.use_cudnn or
self.use_cuda)
def test_check_output(self):
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace()
self.check_grad_with_place(
place, set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace()
self.check_grad_with_place(
place, ['Input'],
'Output',
@@ -128,7 +130,7 @@ class TestConv2dOp(OpTest):
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace()
self.check_grad_with_place(
place, ['Filter'],
'Output',
@@ -325,18 +327,33 @@ class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
class TestDepthwiseConv(TestConv2dOp):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConv2(TestConv2dOp):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConv3(TestConv2dOp):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
@@ -347,6 +364,34 @@ class TestDepthwiseConv2(TestConv2dOp):
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvWithDilation(TestConv2dOp):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvWithDilation2(TestConv2dOp):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
# Please don't remove the following code.
# Currently, CI uses cuDNN V5.0, which does not support dilated conv.
# class TestCUDNNWithDilation(TestWithDilation):
