|
|
|
@ -50,12 +50,18 @@ static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
|
|
|
|
|
// Byte limit associated with the cuDNN convolution workspace: 1024^3 bytes
// (1 GiB). The leading static_cast keeps the whole product in size_t
// arithmetic rather than int.
// NOTE(review): how this limit is enforced is not visible in this chunk —
// confirm against the kernels that read it.
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
    static_cast<size_t>(1024) * 1024 * 1024;
|
|
|
|
|
|
|
|
|
|
// Number of candidate algorithms for each cuDNN convolution pass (forward,
// backward-filter, backward-data).
//
// Each constant is defined exactly once, inside the version guard below. An
// unconditional definition of kNUM_CUDNN_FWD_ALGS must not be added outside
// the guard: it would redefine the name and would reference a *_ALGO_COUNT
// enumerator that older cuDNN headers do not provide.
#if CUDNN_VERSION_MIN(6, 0, 5)
// NOTE(review): presumably true when building against cuDNN >= 6.0.5 —
// confirm against the definition of CUDNN_VERSION_MIN.
// Take the counts directly from the cuDNN enumerators so they track the
// library headers.
static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
    CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
#else
// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc.
// Fall back to hard-coded counts.
static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
#endif
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|