|
|
|
@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
* ('any') which lets a primitive (convolution in this case) choose
|
|
|
|
|
* the memory format preferred for best performance
|
|
|
|
|
*/
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
|
|
|
|
|
// create a conv primitive descriptor and save it for usage in backward
|
|
|
|
|
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
|
|
|
|
@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
* ('any') which lets a primitive (conv backward in this case) choose
|
|
|
|
|
* the memory format preferred for best performance
|
|
|
|
|
*/
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto diff_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto diff_weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
|
|
|
|
|
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
|
|
|
|
|
// Retrieve conv_pd from device context
|
|
|
|
|
auto conv_pd =
|
|
|
|
|