@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
 
     // create mkldnn memory from input x tensor
-    mkldnn::memory::format input_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), x->format());
 
     // keys for backward pass
     const std::string key = BatchNormMKLDNNHandler::GetHash(
-        src_tz, epsilon, flags, global_stats, input_format,
+        src_tz, epsilon, flags, global_stats, x->format(),
         ctx.op().Output("SavedMean"));
     const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
 
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
+    auto user_src_md = x->get_mkldnn_prim_desc().desc();
 
     // create primitive descriptor for batch norm forward
     using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
                                    key);
 
-    auto src_memory =
-        handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
+    auto src_memory = handler.AcquireSrcMemory(x->get_mkldnn_prim_desc(),
+                                               to_void_cast(x_data));
 
     // crate mkldnn memory for weights(scale/shift)
     auto scaleshift_memory =
@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                            variance_memory, false);
     }
 
-    y->set_layout(DataLayout::kMKLDNN);
-    y->set_format(platform::GetMKLDNNFormat(*dst_memory));
+    y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
 
     std::vector<mkldnn::primitive> pipeline;
     pipeline.push_back(*batch_norm_p);
@@ -336,9 +332,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
 
-    mkldnn::memory::format dst_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
-
     mkldnn::memory::format input_format =
         platform::MKLDNNFormatForSize(src_tz.size(), x->format());
 
@@ -346,14 +339,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     // keys from forward pass
     const std::string key = BatchNormMKLDNNHandler::GetHash(
-        src_tz, epsilon, flags, false, input_format,
+        src_tz, epsilon, flags, false, x->format(),
         ctx.op().Input("SavedMean"));
     const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
 
     // keys for primitives reuse
     const std::string key_with_hash =
         key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
-                                              input_format);
+                                              x->format());
     const std::string key_batch_norm_bwd_p =
         key_with_hash + "@batch_norm_bwd_p";
     const std::string key_batch_norm_src_mem_p =
@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     primitive reorder_diff_dst;
     bool is_diff_dst_reordered = false;
-    auto user_diff_dst_memory = memory(
-        {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
-        to_void_cast(diff_y_data));
+    auto user_diff_dst_memory =
+        memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data));
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
     const size_t scaleshift_size = 2 * ic;
@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
       dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
 
       // set layout/format of output tensors
-      diff_x->set_layout(DataLayout::kMKLDNN);
-      diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
-                             .desc()
-                             .data.format);
+      diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
     } else {
       // primitives already exist
      UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
       }
 
       // set layout/format of output tensors
-      diff_x->set_layout(DataLayout::kMKLDNN);
-      diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
-                             .desc()
-                             .data.format);
+      diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
     }
 
     // execute optional reorder and batch_norm backward primitive