commit ee0fd78c81
@@ -0,0 +1,48 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>
#include "paddle/fluid/operators/conv_op.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

// This fused conv follows the equation:
//   y = act ( alpha1 * conv(x) + alpha2 * z + bias ).
// Here, y is Output,
//       x is Input,
//       z is ResidualData,
//       bias is Bias.
class Conv2DFusionOpMaker : public Conv2DOpMaker {
 protected:
  void Apply() override {
    AddAttr<std::string>(
        "activation",
        "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6', "
        "'relux', 'tanh', 'band_pass'.")
        .SetDefault("relu");
  }
};
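For reference, a minimal NumPy sketch of the fused computation described by the comment above, y = act(alpha1 * conv(x) + alpha2 * z + bias). The `naive_conv2d` helper (stride 1, no padding, single group) and the default scaling values are illustrative assumptions, not part of this operator's code:

```python
import numpy as np

def naive_conv2d(x, w):
    # x: [N, C, H, W], w: [M, C, kH, kW]; stride 1, no padding, groups 1.
    n, c, h, wid = x.shape
    m, _, kh, kw = w.shape
    out = np.zeros((n, m, h - kh + 1, wid - kw + 1), dtype=x.dtype)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i:i + kh, j:j + kw]
            # Contract over (C, kH, kW) for every (sample, output channel) pair.
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out

def fused_conv2d(x, w, bias, z=None, alpha1=1.0, alpha2=0.0, act="relu"):
    y = alpha1 * naive_conv2d(x, w)
    if z is not None:
        y += alpha2 * z          # ResidualData term
    y += bias.reshape((1, -1, 1, 1))
    return np.maximum(y, 0) if act == "relu" else y
```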
// TODO(qingqing): add gradient operator for conv2d_fusion

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker,
                  ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker);
@@ -0,0 +1,187 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"

DECLARE_uint64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

template <typename T>
class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.Input<Tensor>("Bias");
    PADDLE_ENFORCE(bias, "The bias should not be null.");
    auto* residual = ctx.Input<Tensor>("ResidualData");
    auto* output = ctx.Output<Tensor>("Output");

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    const std::string activation = ctx.Attr<std::string>("activation");
    int groups = ctx.Attr<int>("groups");
    int64_t user_workspace_size =
        static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
    bool exhaustive_search =
        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    const T* bias_data = bias->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    const T* residual_data = residual ? residual->data<T>() : output_data;

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedFilterDescriptor filter_desc;
    ScopedTensorDescriptor bias_desc;
    ScopedConvolutionDescriptor conv_desc;
    ScopedActivationDescriptor act_desc;
    DataLayout layout = DataLayout::kNCHW;
    if (input->dims().size() == 5) {
      layout = DataLayout::kNCDHW;
    }

    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(paddings, strides, dilations);
    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
        cudnn_conv_desc, groups));

    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output->dims()));
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize2int(filter->dims()));
    // Now only support NCHW
    std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1};
    cudnnTensorDescriptor_t cudnn_bias_desc =
        bias_desc.descriptor<T>(layout, bias_dim);
    cudnnActivationDescriptor_t cudnn_act_desc =
        act_desc.descriptor<T>(activation);

    // ------------------- cudnn conv workspace ---------------------
    size_t workspace_size_in_bytes;  // final workspace to allocate.
    size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }
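A small Python sketch of how the workspace cap above resolves: take the larger of the flag and the attribute hint (both in MB) and convert to bytes, otherwise fall back to the compile-time constant. The default value used below is an assumption for illustration only; the real fallback is kCONV_CUDNN_WORKSPACE_LIMIT_BYTES:

```python
DEFAULT_LIMIT_BYTES = 4096 * 1024 * 1024  # stand-in for kCONV_CUDNN_WORKSPACE_LIMIT_BYTES (assumed)

def workspace_limit_bytes(flag_limit_mb, attr_limit_mb):
    # Mirrors the C++ logic: prefer the larger of the two MB hints,
    # otherwise use the built-in default limit.
    if flag_limit_mb > 0 or attr_limit_mb > 0:
        return max(flag_limit_mb, attr_limit_mb) * 1024 * 1024
    return DEFAULT_LIMIT_BYTES

print(workspace_limit_bytes(0, 512))  # 536870912 bytes when only workspace_size_MB is set
```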

    // ------------------- cudnn conv algorithm ---------------------
    cudnnConvolutionFwdAlgo_t algo;
    auto handle = dev_ctx.cudnn_handle();
    auto workspace_handle = dev_ctx.cudnn_workspace_handle();

    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
        cudnn_conv_desc, CUDNN_DEFAULT_MATH));

    auto x_dims = framework::vectorize(input->dims());
    auto f_dims = framework::vectorize(filter->dims());
    if (activation == "identity") {
      // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
      // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
    } else if (!exhaustive_search) {
      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
          workspace_size_limit, &algo));
      VLOG(3) << "cuDNN forward algo " << algo;
    } else {
      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
      if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
        algo_cache =
            ctx.scope()
                .FindVar(kCUDNNFwdAlgoCache)
                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
      } else {
        algo_cache =
            const_cast<framework::Scope&>(ctx.scope())
                .Var(kCUDNNFwdAlgoCache)
                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
      }
      algo = algo_cache->GetAlgorithm(
          x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
            int returned_algo_count;
            std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
                fwd_perf_stat;
            auto cudnn_find_func = [&](void* cudnn_workspace) {
              CUDNN_ENFORCE(
                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
                      handle, cudnn_input_desc, input_data, cudnn_filter_desc,
                      filter_data, cudnn_conv_desc, cudnn_output_desc,
                      output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
                      fwd_perf_stat.data(), cudnn_workspace,
                      workspace_size_limit));
            };
            workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
            VLOG(3) << "Perf result: (algo: stat, time, memory)";
            for (int i = 0; i < returned_algo_count; ++i) {
              const auto& stat = fwd_perf_stat[i];
              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                      << " " << stat.memory;
            }
            return fwd_perf_stat[0].algo;
          });
      VLOG(3) << "choose algo " << algo;
    }
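The exhaustive-search branch above memoizes the best algorithm per problem shape in an AlgorithmsCache stored in the scope, so the expensive cudnnFindConvolutionForwardAlgorithmEx pass runs once per distinct configuration. A rough Python analogue of that caching pattern (the cache key and the `find_best` callback are illustrative, not Paddle API):

```python
_fwd_algo_cache = {}

def get_algorithm(x_dims, f_dims, strides, paddings, dilations, find_best):
    # Key the cache on everything that changes the convolution problem.
    key = (tuple(x_dims), tuple(f_dims), tuple(strides),
           tuple(paddings), tuple(dilations))
    if key not in _fwd_algo_cache:
        _fwd_algo_cache[key] = find_best()  # expensive search, done once per key
    return _fwd_algo_cache[key]

# Example: the callback would benchmark all forward algorithms and return the
# fastest one; here a constant string stands in for that result.
algo = get_algorithm([2, 3, 5, 5], [6, 3, 3, 3], [1, 1], [0, 0], [1, 1],
                     lambda: "IMPLICIT_PRECOMP_GEMM")
```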

    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
        cudnn_output_desc, algo, &workspace_size_in_bytes));
    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                      "workspace_size to be allocated exceeds the limit");

    // ------------------- cudnn conv+bias+act forward --------------------
    ScalingParamType<T> alpha1 = 1.0f;
    ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
    auto cudnn_func = [&](void* cudnn_workspace) {
      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
          handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
          filter_data, cudnn_conv_desc, algo, cudnn_workspace,
          workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
          cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
          output_data));
    };
    workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
  }
};
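One detail worth noting in the kernel above: when ResidualData is absent, residual_data points at the output buffer but alpha2 is set to 0, so whatever that buffer holds is scaled away and only conv + bias reaches the activation. A tiny NumPy sketch of that scaling behaviour (the values are made up):

```python
import numpy as np

conv_plus_bias = np.array([0.5, -1.0, 2.0])
residual = np.array([10.0, 10.0, 10.0])  # may be arbitrary data when no residual input

for alpha2 in (1.0, 0.0):  # 1.0 with ResidualData, 0.0 without
    y = np.maximum(conv_plus_bias + alpha2 * residual, 0)  # relu activation
    print(alpha2, y)
# alpha2=1.0 -> [10.5  9.  12.], alpha2=0.0 -> [0.5 0.  2.]
```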

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
                        ops::CUDNNConvFusionOpKernel<double>);
@@ -0,0 +1,158 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid.core as core
from op_test import OpTest

from test_conv2d_op import conv2d_forward_naive


class TestConv2dFusionOp(OpTest):
    def setUp(self):
        self.op_type = "conv2d_fusion"
        self.exhaustive_search = False
        self.data_format = "AnyLayout"
        self.dtype = np.float32
        self.activation = 'relu'
        self.add_bias = True
        self.add_residual_data = True

        self.init_group()
        self.init_dilation()
        self.init_test_case()
        self.init_bias_residual()
        self.init_activation()
        self.set_search_method()

        conv2d_param = {
            'stride': self.stride,
            'pad': self.pad,
            'dilation': self.dilations
        }

        input = np.random.random(self.input_size).astype(self.dtype)
        filter = np.random.random(self.filter_size).astype(self.dtype)

        output = conv2d_forward_naive(input, filter, self.groups,
                                      conv2d_param).astype(self.dtype)

        self.inputs = {
            'Input': OpTest.np_dtype_to_fluid_dtype(input),
            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
        }

        if self.add_residual_data:
            residual_data = np.random.random(output.shape).astype(self.dtype)
            self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
                residual_data)
            output += residual_data

        if self.add_bias:
            bias = np.random.random(self.filter_size[0]).astype(self.dtype)
            self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias)
            output = output + bias.reshape((1, bias.size, 1, 1))

        assert self.activation in ['relu', 'identity']
        if self.activation == 'relu':
            output = np.maximum(output, 0)

        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
            'groups': self.groups,
            'dilations': self.dilations,
            'data_format': self.data_format,
            'exhaustive_search': self.exhaustive_search,
            'activation': self.activation
        }
        self.outputs = {'Output': output}

    def testcuda(self):
        return core.is_compiled_with_cuda()

    def test_check_output(self):
        if self.testcuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-5)
        else:
            pass

    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [2, 3, 5, 5]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 3, 3]

    def init_dilation(self):
        self.dilations = [1, 1]

    def init_group(self):
        self.groups = 1

    def init_bias_residual(self):
        self.add_bias = True
        self.add_residual_data = True

    def init_activation(self):
        self.activation = 'relu'

    def set_search_method(self):
        self.exhaustive_search = False


class TestWithoutResidual(TestConv2dFusionOp):
    def init_bias_residual(self):
        self.add_residual_data = False


class TestIdentityActivation(TestConv2dFusionOp):
    def init_activation(self):
        self.activation = 'identity'


class TestWithGroup(TestConv2dFusionOp):
    def init_group(self):
        self.groups = 3


class TestWithDilation(TestConv2dFusionOp):
    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [2, 3, 10, 10]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 3, 3]

    def init_dilation(self):
        self.dilations = [2, 2]

    def init_group(self):
        self.groups = 3


class TestCUDNNExhaustiveSearch(TestConv2dFusionOp):
    def set_search_method(self):
        self.exhaustive_search = True


if __name__ == '__main__':
    unittest.main()
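If further coverage were wanted, another variant could be added above the `__main__` guard by overriding the init hooks, following the pattern of the subclasses above; a sketch (the class name and the combination are hypothetical, not part of this commit):

```python
class TestWithGroupIdentityActivation(TestConv2dFusionOp):
    # Combines grouped convolution with the identity activation path,
    # which forces the IMPLICIT_PRECOMP_GEMM algorithm in the kernel.
    def init_group(self):
        self.groups = 3

    def init_activation(self):
        self.activation = 'identity'
```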