|
|
|
@ -265,6 +265,16 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
|
|
|
|
|
workspace_size = search::GetWorkspaceSize(args, algo);
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
|
// when groups > 1, SearchAlgorithm find algo is CUDNN_CONVOLUTION_\
|
|
|
|
|
// FWD_ALGO_WINOGRAD_NONFUSED, but this kind of algorithm is unstable
|
|
|
|
|
// in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\
|
|
|
|
|
// FWD_ALGO_IMPLICIT_GEMM manually.
|
|
|
|
|
if (ctx.Attr<int>("groups") > 1) {
|
|
|
|
|
algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn conv forward ---------------------
|
|
|
|
|
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
|
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
@ -805,6 +815,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
#if CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
|
iwo_group = 1;
|
|
|
|
|
c_group = groups;
|
|
|
|
|
groups = 1;
|
|
|
|
|
#endif
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
|
|
|
|
|