|
|
|
@ -115,9 +115,16 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory from input x tensor
|
|
|
|
|
auto src_memory =
|
|
|
|
|
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
|
|
|
|
|
to_void_cast(x_data));
|
|
|
|
|
mkldnn::memory::format input_format = x->format();
|
|
|
|
|
if (src_tz.size() == 1) {
|
|
|
|
|
input_format = mkldnn::memory::format::x;
|
|
|
|
|
} else if (src_tz.size() == 2) {
|
|
|
|
|
input_format = mkldnn::memory::format::nc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto src_memory = memory(
|
|
|
|
|
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
|
|
|
|
|
to_void_cast(x_data));
|
|
|
|
|
|
|
|
|
|
// create primitive descriptor for batch norm forward
|
|
|
|
|
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
|
|
|
|
@ -251,15 +258,28 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory from input diff_y tensor
|
|
|
|
|
auto user_diff_dst_memory =
|
|
|
|
|
memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
|
|
|
|
|
mkldnn_engine},
|
|
|
|
|
to_void_cast(diff_y_data));
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::format dst_format = x->format();
|
|
|
|
|
if (diff_dst_tz.size() == 1) {
|
|
|
|
|
dst_format = mkldnn::memory::format::x;
|
|
|
|
|
} else if (diff_dst_tz.size() == 2) {
|
|
|
|
|
dst_format = mkldnn::memory::format::nc;
|
|
|
|
|
}
|
|
|
|
|
auto user_diff_dst_memory = memory(
|
|
|
|
|
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
|
|
|
|
|
to_void_cast(diff_y_data));
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory from input x tensor
|
|
|
|
|
auto src_memory =
|
|
|
|
|
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
|
|
|
|
|
to_void_cast(x_data));
|
|
|
|
|
mkldnn::memory::format input_format = x->format();
|
|
|
|
|
if (src_tz.size() == 1) {
|
|
|
|
|
input_format = mkldnn::memory::format::x;
|
|
|
|
|
} else if (src_tz.size() == 2) {
|
|
|
|
|
input_format = mkldnn::memory::format::nc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto src_memory = memory(
|
|
|
|
|
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
|
|
|
|
|
to_void_cast(x_data));
|
|
|
|
|
|
|
|
|
|
// for diff_dst, try to use same format as dst in forward pass
|
|
|
|
|
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
|
|
|
|
|