|
|
@ -240,7 +240,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
auto layout_format = GetCudnnTensorFormat(layout);
|
|
|
|
auto layout_format = GetCudnnTensorFormat(layout);
|
|
|
|
|
|
|
|
|
|
|
|
args.handle = handle;
|
|
|
|
args.handle = handle;
|
|
|
|
args.cdesc.set(dtype, padding_common, strides, dilations);
|
|
|
|
args.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn());
|
|
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
#if CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
// cudnn 7 can support groups, no need to do it manually
|
|
|
|
// cudnn 7 can support groups, no need to do it manually
|
|
|
@ -603,7 +604,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
args1.idesc.set(transformed_input_grad, layout_tensor);
|
|
|
|
args1.idesc.set(transformed_input_grad, layout_tensor);
|
|
|
|
args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
|
|
|
|
args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
|
|
|
|
args1.odesc.set(transformed_output_grad_channel, layout_tensor);
|
|
|
|
args1.odesc.set(transformed_output_grad_channel, layout_tensor);
|
|
|
|
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
|
|
|
|
args1.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn(), c_groups);
|
|
|
|
|
|
|
|
|
|
|
|
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
|
|
|
|
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
|
|
|
|
data_algo =
|
|
|
|
data_algo =
|
|
|
@ -620,7 +622,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
args2.wdesc.set(transformed_filter_grad_channel, layout_tensor,
|
|
|
|
args2.wdesc.set(transformed_filter_grad_channel, layout_tensor,
|
|
|
|
iwo_groups);
|
|
|
|
iwo_groups);
|
|
|
|
args2.odesc.set(transformed_output_grad_channel, layout_tensor);
|
|
|
|
args2.odesc.set(transformed_output_grad_channel, layout_tensor);
|
|
|
|
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
|
|
|
|
args2.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn(), c_groups);
|
|
|
|
|
|
|
|
|
|
|
|
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
|
|
|
|
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
|
|
|
|
filter_algo =
|
|
|
|
filter_algo =
|
|
|
@ -980,7 +983,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
args1.idesc.set(transformed_ddX, iwo_group);
|
|
|
|
args1.idesc.set(transformed_ddX, iwo_group);
|
|
|
|
args1.wdesc.set(*W, layout, iwo_group);
|
|
|
|
args1.wdesc.set(*W, layout, iwo_group);
|
|
|
|
args1.odesc.set(transformed_ddO_channel, iwo_group);
|
|
|
|
args1.odesc.set(transformed_ddO_channel, iwo_group);
|
|
|
|
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
|
|
|
|
args1.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn(), c_group);
|
|
|
|
|
|
|
|
|
|
|
|
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
|
|
|
|
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
|
|
|
|
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
|
|
|
|
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
|
|
|
@ -995,7 +999,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
args2.wdesc.set(*ddW, layout, iwo_group);
|
|
|
|
args2.wdesc.set(*ddW, layout, iwo_group);
|
|
|
|
|
|
|
|
|
|
|
|
args2.odesc.set(transformed_ddO_channel, iwo_group);
|
|
|
|
args2.odesc.set(transformed_ddO_channel, iwo_group);
|
|
|
|
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
|
|
|
|
args2.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn(), c_group);
|
|
|
|
|
|
|
|
|
|
|
|
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
|
|
|
|
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
|
|
|
|
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
|
|
|
|
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
|
|
|
@ -1012,7 +1017,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
args3.odesc.set(transformed_dO_channel, iwo_group);
|
|
|
|
args3.odesc.set(transformed_dO_channel, iwo_group);
|
|
|
|
|
|
|
|
|
|
|
|
args3.cdesc.set(dtype, padding_common, strides, dilations, c_group);
|
|
|
|
args3.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn(), c_group);
|
|
|
|
|
|
|
|
|
|
|
|
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
|
|
|
|
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
|
|
|
|
filter_algo =
|
|
|
|
filter_algo =
|
|
|
@ -1028,7 +1034,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
args4.idesc.set(transformed_dX, iwo_group);
|
|
|
|
args4.idesc.set(transformed_dX, iwo_group);
|
|
|
|
args4.wdesc.set(*ddW, layout, iwo_group);
|
|
|
|
args4.wdesc.set(*ddW, layout, iwo_group);
|
|
|
|
args4.odesc.set(transformed_dO_channel, iwo_group);
|
|
|
|
args4.odesc.set(transformed_dO_channel, iwo_group);
|
|
|
|
args4.cdesc.set(dtype, padding_common, strides, dilations, c_group);
|
|
|
|
args4.cdesc.set(dtype, padding_common, strides, dilations,
|
|
|
|
|
|
|
|
platform::AllowTF32Cudnn(), c_group);
|
|
|
|
|
|
|
|
|
|
|
|
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
|
|
|
|
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
|
|
|
|
data_algo =
|
|
|
|
data_algo =
|
|
|
|