|
|
|
@ -155,11 +155,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
weights_format = mkldnn::memory::format::any;
|
|
|
|
|
// Check the format for user's special output
|
|
|
|
|
if (chosen_memory_format != mkldnn::memory::format::any) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
@ -435,11 +438,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
weights_format = mkldnn::memory::format::any;
|
|
|
|
|
// Check the format for user's special output
|
|
|
|
|
if (chosen_memory_format != mkldnn::memory::format::any) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|