fix the bug of conv_transpose cudnn kernel, test=develop (#20958)

fix the bug of conv_transpose cudnn kernel: before version 1.6, the data_format is AnyLayout in inference model. When use version 1.6 and load the model which is saved by previous version, the error occurs.  This is because the cudnn kernel in version 1.6 is not compitable with Anylayout setting.
custom_op_abi
Zhang Ting 6 years ago committed by lanxianghit
parent 7695b713e1
commit f4f85831d3

@ -72,7 +72,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
const T* filter_data = filter->data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
(data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first
Tensor input_transpose;

Loading…
Cancel
Save