@@ -12,7 +12,6 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "glog/logging.h"
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memory.h"
@@ -69,13 +68,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
         conv_desc.descriptor<T>(paddings, strides, dilations);
 
-    int input_channels = input->dims()[1];    // M
-    int input_height = input->dims()[2];      // H
-    int input_width = input->dims()[3];       // W
-    int output_channels = output->dims()[1];  // C
-    int output_height = output->dims()[2];    // O_H
-    int output_width = output->dims()[3];     // O_W
-
     // ------------------- cudnn conv workspace ---------------------
     void* cudnn_workspace = nullptr;
     size_t workspace_size_in_bytes;  // final workspace to allocate.
@@ -118,7 +110,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
   }
 };
 
-/*
 template <typename T>
 class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
  public:
@@ -130,7 +121,6 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
     auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
     auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
-
     const T* input_data = input->data<T>();
     const T* output_grad_data = output_grad->data<T>();
     const T* filter_data = filter->data<T>();
@@ -138,47 +128,33 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
-    int groups = ctx.Attr<int>("groups");
     int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
 
     // ------------------- cudnn descriptors ---------------------
     ScopedTensorDescriptor input_desc;
-    ScopedTensorDescriptor output_grad_desc;
-    ScopedTensorDescriptor input_grad_desc;
+    ScopedTensorDescriptor output_desc;
     ScopedFilterDescriptor filter_desc;
-    ScopedFilterDescriptor filter_grad_desc;
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
     // Input: (N, M, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
-        layout, framework::vectorize2int(input->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_output_grad_desc =
-        output_grad_desc.descriptor<T>(
-            layout, framework::vectorize2int(output_grad->dims()), groups);
+        layout, framework::vectorize2int(input->dims()));
+    // Output: (N, C, O_H, O_W)
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize2int(output_grad->dims()));
+    // Filter (M, C, K_H, K_W)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
-        layout, framework::vectorize2int(filter->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr;
-    cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr;
+        layout, framework::vectorize2int(filter->dims()));
 
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
         conv_desc.descriptor<T>(paddings, strides, dilations);
 
-    int input_channels = input->dims()[1];
-    int input_height = input->dims()[2];
-    int input_width = input->dims()[3];
-    int output_grad_channels = filter->dims()[0];
-    int output_grad_height = output_grad->dims()[2];
-    int output_grad_width = output_grad->dims()[3];
-
-    int group_offset_in = input_channels / groups * input_height * input_width;
-    int group_offset_out =
-        output_grad_channels / groups * output_grad_height * output_grad_width;
-    int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn backward algorithm ---------------------
-    cudnnConvolutionBwdDataAlgo_t data_algo;
+    cudnnConvolutionFwdAlgo_t data_algo;
     cudnnConvolutionBwdFilterAlgo_t filter_algo;
-    size_t workspace_size_in_bytes = 0, tmp_size = 0;
+    size_t bwd_filter_ws_size, fwd_ws_size;
+    size_t workspace_size_in_bytes = 0;
     size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
     if (user_workspace_size > 0) {
       workspace_size_limit = user_workspace_size * 1024 * 1024;
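The descriptor rewrite in this hunk works because a transposed convolution is the adjoint of an ordinary convolution, so the grad kernel can describe everything with plain forward-convolution descriptors: output_grad takes the image slot and Input the response slot. Below is a minimal 1-D CPU sketch of that identity, assuming stride 1, zero padding, and cuDNN's no-flip (cross-correlation) convention; the names and code are illustrative only, not Paddle or cuDNN API.

    #include <cassert>
    #include <vector>

    // Ordinary 1-D "valid" convolution (stride 1, no padding, no kernel flip).
    std::vector<float> Conv1D(const std::vector<float>& in,
                              const std::vector<float>& k) {
      std::vector<float> out(in.size() - k.size() + 1, 0.f);
      for (size_t i = 0; i < out.size(); ++i)
        for (size_t j = 0; j < k.size(); ++j) out[i] += in[i + j] * k[j];
      return out;
    }

    // 1-D transposed convolution: scatter each input element through the kernel.
    std::vector<float> ConvTranspose1D(const std::vector<float>& in,
                                       const std::vector<float>& k) {
      std::vector<float> out(in.size() + k.size() - 1, 0.f);
      for (size_t i = 0; i < in.size(); ++i)
        for (size_t j = 0; j < k.size(); ++j) out[i + j] += in[i] * k[j];
      return out;
    }

    int main() {
      // Adjoint check: <ConvTranspose1D(x, k), dy> == <x, Conv1D(dy, k)>,
      // i.e. d(Input) of a conv-transpose is an ordinary convolution of
      // d(Output) with the same filter.
      std::vector<float> x = {1, 2, 3}, k = {4, 5}, dy = {1, -1, 2, 0};
      auto y = ConvTranspose1D(x, k);  // {4, 13, 22, 15}
      auto dx = Conv1D(dy, k);         // {-1, 6, 8}
      float lhs = 0.f, rhs = 0.f;
      for (size_t i = 0; i < y.size(); ++i) lhs += y[i] * dy[i];
      for (size_t i = 0; i < x.size(); ++i) rhs += x[i] * dx[i];
      assert(lhs == rhs);  // 35 == 35, exact for these small integer values
      return 0;
    }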
@@ -186,42 +162,35 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     }
 
     auto handle = ctx.cuda_device_context().cudnn_handle();
     if (input_grad) {
-      cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
-          layout, framework::vectorize2int(input_grad->dims()), groups);
-      PADDLE_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-              handle, cudnn_filter_desc,
-              // dyDesc: Handle to the previously initialized input differential
-              // tensor descriptor.
-              cudnn_output_grad_desc, cudnn_conv_desc,
-              // dxDesc: Handle to the previously initialized output tensor
-              // descriptor.
-              cudnn_input_grad_desc,
-              CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-              workspace_size_limit, &data_algo));
-      PADDLE_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-              handle, cudnn_filter_desc, cudnn_output_grad_desc,
-              cudnn_conv_desc, cudnn_input_grad_desc, data_algo, &tmp_size));
-      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+      // choose backward algorithm for data
+      PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+          handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+          workspace_size_limit, &data_algo));
+      PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+          handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_input_desc, data_algo, &fwd_ws_size));
+      workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size);
     }
 
     if (filter_grad) {
-      cudnn_filter_grad_desc = filter_grad_desc.descriptor<T>(
-          layout, framework::vectorize2int(filter_grad->dims()), groups);
       // choose backward algorithm for filter
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
-              handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
+              handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
               cudnn_filter_desc,
               CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
               workspace_size_limit, &filter_algo));
 
       // get workspace for backwards filter algorithm
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
-              handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
-              cudnn_filter_desc, filter_algo, &tmp_size));
-      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+              handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
+              cudnn_filter_desc, filter_algo, &bwd_filter_ws_size));
+      workspace_size_in_bytes =
+          std::max(workspace_size_in_bytes, bwd_filter_ws_size);
     }
 
     // ------------------- cudnn conv workspace ---------------------
     // Already on GPU
     void* cudnn_workspace = nullptr;
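Note the workspace handling shared by the two branches above: each cuDNN size query only grows workspace_size_in_bytes, so a single buffer can later serve both the data-gradient and the filter-gradient calls. A stand-alone sketch of the pattern (hypothetical helper name, not Paddle code):

    #include <algorithm>
    #include <cstddef>

    // Each cuDNN "get workspace size" query reports the scratch space its
    // chosen algorithm needs; keeping the running max lets the kernel
    // allocate one buffer large enough for both gradient computations.
    size_t SharedWorkspaceBytes(size_t fwd_ws_size, size_t bwd_filter_ws_size) {
      size_t workspace_size_in_bytes = 0;
      workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size);
      workspace_size_in_bytes =
          std::max(workspace_size_in_bytes, bwd_filter_ws_size);
      return workspace_size_in_bytes;
    }

This is also why the single cudnn_workspace allocation sits after both if blocks rather than inside them.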
@@ -235,35 +204,30 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
       auto t = framework::EigenVector<T>::Flatten(*input_grad);
       t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
           t.constant(static_cast<T>(0));
-      for (int i = 0; i < groups; i++) {
-        PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-            handle, &alpha, cudnn_filter_desc,
-            filter_data + i * group_offset_filter, cudnn_output_grad_desc,
-            output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
-            cudnn_workspace, workspace_size_in_bytes, &beta,
-            cudnn_input_grad_desc, input_grad_data + i * group_offset_in));
-      }
+
+      PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
+          handle, &alpha, cudnn_output_desc, output_grad_data,
+          cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
+          cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
+          input_grad_data));
     }
 
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       auto t = framework::EigenVector<T>::Flatten(*filter_grad);
       t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
           t.constant(static_cast<T>(0));
-      for (int i = 0; i < groups; i++) {
-        PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-            cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
-            cudnn_conv_desc, filter_algo, cudnn_workspace,
-            workspace_size_in_bytes, &beta, cudnn_filter_grad_desc,
-            filter_grad_data + i * group_offset_filter));
-      }
+      // Gradient with respect to the filter
+      PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+          handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
+          input_data, cudnn_conv_desc, filter_algo, cudnn_workspace,
+          workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data));
     }
     // Release the cudnn workspace
     paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
-*/
 
 }  // namespace operators
 }  // namespace paddle
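The argument order in the cudnnConvolutionBackwardFilter call above mirrors the same adjoint view: for a conv-transpose, the filter gradient correlates output_grad (acting as the image) against windows indexed by Input (acting as the diff). A 1-D sketch of that formula, continuing the illustrative non-Paddle example from earlier:

    #include <vector>

    // For the scatter rule out[i + j] += x[i] * k[j] used by the transposed
    // convolution, the chain rule gives dL/dk[j] = sum_i x[i] * dy[i + j]:
    // dy plays the image role and x the sliding-window role, mirroring the
    // descriptor order passed to cudnnConvolutionBackwardFilter above.
    std::vector<float> ConvTransposeFilterGrad(const std::vector<float>& x,
                                               const std::vector<float>& dy,
                                               size_t ksize) {
      // expects dy.size() == x.size() + ksize - 1
      std::vector<float> dk(ksize, 0.f);
      for (size_t j = 0; j < ksize; ++j)
        for (size_t i = 0; i < x.size(); ++i) dk[j] += x[i] * dy[i + j];
      return dk;
    }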
@@ -272,5 +236,5 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn,
                        ops::CudnnConvTransposeOpKernel<float>);
-// REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad,
-//                        ops::CudnnConvTransposeGradOpKernel<float>);
+REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad,
+                       ops::CudnnConvTransposeGradOpKernel<float>);