fix data layout

release/0.11.0
chengduoZH 7 years ago
parent e825a49259
commit 7c2fd61869

@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));

Loading…
Cancel
Save