|
|
|
@ -28,7 +28,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
|
|
|
|
|
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
|
|
|
|
|
using DataLayout = platform::DataLayout;
|
|
|
|
|
|
|
|
|
|
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
|
|
|
|
|
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
|
|
|
|
|
static_cast<size_t>(1024) * 1024 * 1024;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CudnnConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
@ -44,7 +45,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
|
|
|
|
|
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
|
|
|
|
|
int groups = ctx.Attr<int>("groups");
|
|
|
|
|
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
|
|
|
|
|
int64_t user_workspace_size =
|
|
|
|
|
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
const T* filter_data = filter->data<T>();
|
|
|
|
@ -163,7 +165,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
|
|
|
|
|
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
|
|
|
|
|
int groups = ctx.Attr<int>("groups");
|
|
|
|
|
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
|
|
|
|
|
int64_t user_workspace_size =
|
|
|
|
|
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn descriptors ---------------------
|
|
|
|
|
ScopedTensorDescriptor input_desc;
|
|
|
|
|