|
|
@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search);
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
|
|
|
|
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
|
|
|
|
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
|
|
|
|
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
|
|
|
@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
|
|
|
|
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
|
|
|
|
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
|
|
|
|
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
|
|
|
|
ops::CUDNNConvFusionOpKernel<double>);
|
|
|
|
ops::CUDNNConvFusionOpKernel<double>);
|
|
|
|
|
|
|
|
#endif
|
|
|
|