|
|
|
@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
ScopedTensorDescriptor input_desc;
|
|
|
|
|
ScopedTensorDescriptor output_desc;
|
|
|
|
|
ScopedPoolingDescriptor pool_desc;
|
|
|
|
|
DataLayout layout = DataLayout::kNCHW;
|
|
|
|
|
DataLayout layout;
|
|
|
|
|
|
|
|
|
|
if (strides.size() == 2U) {
|
|
|
|
|
layout = DataLayout::kNCHW;
|
|
|
|
|
} else {
|
|
|
|
|
layout = DataLayout::kNCDHW;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
|
|
|
|
layout, framework::vectorize2int(input->dims()));
|
|
|
|
@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
ScopedTensorDescriptor input_desc;
|
|
|
|
|
ScopedTensorDescriptor output_desc;
|
|
|
|
|
ScopedPoolingDescriptor pool_desc;
|
|
|
|
|
DataLayout layout = DataLayout::kNCHW;
|
|
|
|
|
DataLayout layout;
|
|
|
|
|
|
|
|
|
|
if (strides.size() == 2U) {
|
|
|
|
|
layout = DataLayout::kNCHW;
|
|
|
|
|
} else {
|
|
|
|
|
layout = DataLayout::kNCDHW;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
|
|
|
|
layout, framework::vectorize2int(input->dims()));
|
|
|
|
|