|
|
|
@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
* ('any') which lets a primitive (conv backward in this case) choose
|
|
|
|
|
* the memory format preferred for best performance
|
|
|
|
|
*/
|
|
|
|
|
auto chosen_memory_format = MKLDNNMemoryFormat::any;
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Once GRAD NHWC is working then format 'any'
|
|
|
|
|
// should be used exclusively. But till forward pass enforce
|
|
|
|
|
// NCHW for training we need to have NCHW here as well
|
|
|
|
|
// to avoid performance degradation in relu_grad and pool2d_grad
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
weights_format = MKLDNNMemoryFormat::any;
|
|
|
|
|
// Check the format for user's special output
|
|
|
|
|
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|