|
|
|
@ -201,7 +201,8 @@ void hl_conv_workspace(hl_tensor_descriptor input,
|
|
|
|
|
int* convBwdDataAlgo,
|
|
|
|
|
size_t* bwdDataLimitBytes,
|
|
|
|
|
int* convBwdFilterAlgo,
|
|
|
|
|
size_t* bwdFilterLimitBytes) {
|
|
|
|
|
size_t* bwdFilterLimitBytes,
|
|
|
|
|
bool useDilation) {
|
|
|
|
|
#if CUDNN_VERSION >= 4000
|
|
|
|
|
|
|
|
|
|
CHECK_NOTNULL(input);
|
|
|
|
@ -213,21 +214,60 @@ void hl_conv_workspace(hl_tensor_descriptor input,
|
|
|
|
|
size_t memoryLimitBytes =
|
|
|
|
|
(1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb;
|
|
|
|
|
|
|
|
|
|
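// Algorithm index 0: when useDilation is set, the three algo pointer
// parameters are pointed at this local instead of being filled by cuDNN queries.
// NOTE(review): reassigning the pointer parameters only changes the local
// copies; confirm whether the caller's algo values should be written through.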
|
|
|
|
|
int algo = 0;
|
|
|
|
|
|
|
|
|
|
// cudnn convolution forward configuration
|
|
|
|
|
cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input);
|
|
|
|
|
cudnnTensorDescriptor_t fwd_dest_desc = GET_TENSOR_DESCRIPTOR(output);
|
|
|
|
|
cudnnFilterDescriptor_t fwd_filter_desc = GET_FILTER_DESCRIPTOR(filter);
|
|
|
|
|
cudnnConvolutionDescriptor_t fwd_conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
|
|
|
|
|
// cudnn convolution backward data configuration
|
|
|
|
|
cudnnFilterDescriptor_t bwd_data_filter_desc = GET_FILTER_DESCRIPTOR(filter);
|
|
|
|
|
cudnnTensorDescriptor_t bwd_data_diff_desc = GET_TENSOR_DESCRIPTOR(output);
|
|
|
|
|
cudnnTensorDescriptor_t bwd_data_grad_desc = GET_TENSOR_DESCRIPTOR(input);
|
|
|
|
|
cudnnConvolutionDescriptor_t bwd_data_conv_desc =
|
|
|
|
|
GET_CONVOLUTION_DESCRIPTOR(conv);
|
|
|
|
|
// cudnn convolution backward filter configuration
|
|
|
|
|
cudnnTensorDescriptor_t bwd_filter_src_desc = GET_TENSOR_DESCRIPTOR(input);
|
|
|
|
|
cudnnTensorDescriptor_t bwd_filter_diff_desc = GET_TENSOR_DESCRIPTOR(output);
|
|
|
|
|
cudnnConvolutionDescriptor_t bwd_filter_conv_desc =
|
|
|
|
|
GET_CONVOLUTION_DESCRIPTOR(conv);
|
|
|
|
|
cudnnFilterDescriptor_t bwd_filter_grad_desc = GET_FILTER_DESCRIPTOR(filter);
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
fwd_src_desc,
|
|
|
|
|
fwd_filter_desc,
|
|
|
|
|
fwd_conv_desc,
|
|
|
|
|
fwd_dest_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
memoryLimitBytes,
|
|
|
|
|
reinterpret_cast<cudnnConvolutionFwdAlgo_t*>(convFwdAlgo)));
|
|
|
|
|
if (useDilation) {
|
|
|
|
|
convFwdAlgo = &algo;
|
|
|
|
|
convBwdDataAlgo = &algo;
|
|
|
|
|
convBwdFilterAlgo = &algo;
|
|
|
|
|
} else {
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
fwd_src_desc,
|
|
|
|
|
fwd_filter_desc,
|
|
|
|
|
fwd_conv_desc,
|
|
|
|
|
fwd_dest_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
memoryLimitBytes,
|
|
|
|
|
reinterpret_cast<cudnnConvolutionFwdAlgo_t*>(convFwdAlgo)));
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
bwd_data_filter_desc,
|
|
|
|
|
bwd_data_diff_desc,
|
|
|
|
|
bwd_data_conv_desc,
|
|
|
|
|
bwd_data_grad_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
memoryLimitBytes,
|
|
|
|
|
reinterpret_cast<cudnnConvolutionBwdDataAlgo_t*>(convBwdDataAlgo)));
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
bwd_filter_src_desc,
|
|
|
|
|
bwd_filter_diff_desc,
|
|
|
|
|
bwd_filter_conv_desc,
|
|
|
|
|
bwd_filter_grad_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
memoryLimitBytes,
|
|
|
|
|
reinterpret_cast<cudnnConvolutionBwdFilterAlgo_t*>(convBwdFilterAlgo)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionForwardWorkspaceSize(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
@ -238,23 +278,6 @@ void hl_conv_workspace(hl_tensor_descriptor input,
|
|
|
|
|
static_cast<cudnnConvolutionFwdAlgo_t>(*convFwdAlgo),
|
|
|
|
|
fwdLimitBytes));
|
|
|
|
|
|
|
|
|
|
// cudnn convolution backward data configuration
|
|
|
|
|
cudnnFilterDescriptor_t bwd_data_filter_desc = GET_FILTER_DESCRIPTOR(filter);
|
|
|
|
|
cudnnTensorDescriptor_t bwd_data_diff_desc = GET_TENSOR_DESCRIPTOR(output);
|
|
|
|
|
cudnnTensorDescriptor_t bwd_data_grad_desc = GET_TENSOR_DESCRIPTOR(input);
|
|
|
|
|
cudnnConvolutionDescriptor_t bwd_data_conv_desc =
|
|
|
|
|
GET_CONVOLUTION_DESCRIPTOR(conv);
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
bwd_data_filter_desc,
|
|
|
|
|
bwd_data_diff_desc,
|
|
|
|
|
bwd_data_conv_desc,
|
|
|
|
|
bwd_data_grad_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
memoryLimitBytes,
|
|
|
|
|
reinterpret_cast<cudnnConvolutionBwdDataAlgo_t*>(convBwdDataAlgo)));
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
bwd_data_filter_desc,
|
|
|
|
@ -264,23 +287,6 @@ void hl_conv_workspace(hl_tensor_descriptor input,
|
|
|
|
|
static_cast<cudnnConvolutionBwdDataAlgo_t>(*convBwdDataAlgo),
|
|
|
|
|
bwdDataLimitBytes));
|
|
|
|
|
|
|
|
|
|
// cudnn convolution backward filter configuration
|
|
|
|
|
cudnnTensorDescriptor_t bwd_filter_src_desc = GET_TENSOR_DESCRIPTOR(input);
|
|
|
|
|
cudnnTensorDescriptor_t bwd_filter_diff_desc = GET_TENSOR_DESCRIPTOR(output);
|
|
|
|
|
cudnnConvolutionDescriptor_t bwd_filter_conv_desc =
|
|
|
|
|
GET_CONVOLUTION_DESCRIPTOR(conv);
|
|
|
|
|
cudnnFilterDescriptor_t bwd_filter_grad_desc = GET_FILTER_DESCRIPTOR(filter);
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
bwd_filter_src_desc,
|
|
|
|
|
bwd_filter_diff_desc,
|
|
|
|
|
bwd_filter_conv_desc,
|
|
|
|
|
bwd_filter_grad_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
memoryLimitBytes,
|
|
|
|
|
reinterpret_cast<cudnnConvolutionBwdFilterAlgo_t*>(convBwdFilterAlgo)));
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
|
|
|
t_resource.cudnn_handle,
|
|
|
|
|
bwd_filter_src_desc,
|
|
|
|
@ -603,7 +609,9 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
|
|
|
|
|
int padding_height,
|
|
|
|
|
int padding_width,
|
|
|
|
|
int stride_height,
|
|
|
|
|
int stride_width) {
|
|
|
|
|
int stride_width,
|
|
|
|
|
int dilation_h,
|
|
|
|
|
int dilation_w) {
|
|
|
|
|
CHECK_NOTNULL(conv);
|
|
|
|
|
|
|
|
|
|
cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)malloc(
|
|
|
|
@ -625,18 +633,23 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
|
|
|
|
|
padding_width,
|
|
|
|
|
stride_height,
|
|
|
|
|
stride_width,
|
|
|
|
|
1,
|
|
|
|
|
1,
|
|
|
|
|
dilation_h,
|
|
|
|
|
dilation_w,
|
|
|
|
|
mode,
|
|
|
|
|
data_type));
|
|
|
|
|
#else
|
|
|
|
|
if (dilation_h > 1 || dilation_w > 1) {
|
|
|
|
|
LOG(FATAL)
|
|
|
|
|
<< "Current cudnn version does't support for dilation convolution.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
|
|
|
|
|
padding_height,
|
|
|
|
|
padding_width,
|
|
|
|
|
stride_height,
|
|
|
|
|
stride_width,
|
|
|
|
|
1,
|
|
|
|
|
1,
|
|
|
|
|
dilation_h,
|
|
|
|
|
dilation_w,
|
|
|
|
|
mode));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -659,7 +672,9 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
|
|
|
|
|
int padding_height,
|
|
|
|
|
int padding_width,
|
|
|
|
|
int stride_height,
|
|
|
|
|
int stride_width) {
|
|
|
|
|
int stride_width,
|
|
|
|
|
int dilation_h,
|
|
|
|
|
int dilation_w) {
|
|
|
|
|
CHECK_NOTNULL(conv);
|
|
|
|
|
CHECK_NOTNULL(image);
|
|
|
|
|
CHECK_NOTNULL(filter);
|
|
|
|
@ -678,8 +693,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
|
|
|
|
|
padding_width,
|
|
|
|
|
stride_height,
|
|
|
|
|
stride_width,
|
|
|
|
|
1,
|
|
|
|
|
1,
|
|
|
|
|
dilation_h,
|
|
|
|
|
dilation_w,
|
|
|
|
|
mode,
|
|
|
|
|
data_type));
|
|
|
|
|
#else
|
|
|
|
@ -688,8 +703,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
|
|
|
|
|
padding_width,
|
|
|
|
|
stride_height,
|
|
|
|
|
stride_width,
|
|
|
|
|
1,
|
|
|
|
|
1,
|
|
|
|
|
dilation_h,
|
|
|
|
|
dilation_w,
|
|
|
|
|
mode));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|