|
|
|
@ -31,9 +31,9 @@ class BatchNormMKLDNNHandler
|
|
|
|
|
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
|
|
|
|
|
mkldnn::batch_normalization_backward> {
|
|
|
|
|
public:
|
|
|
|
|
BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon,
|
|
|
|
|
const unsigned &flags, const bool &global_stats,
|
|
|
|
|
const MKLDNNMemoryFormat fmt,
|
|
|
|
|
BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
|
|
|
|
|
const mkldnn::normalization_flags &flags,
|
|
|
|
|
const bool &global_stats, const MKLDNNMemoryFormat fmt,
|
|
|
|
|
const platform::MKLDNNDeviceContext &dev_ctx,
|
|
|
|
|
platform::Place cpu_place,
|
|
|
|
|
const std::string &uniq_name)
|
|
|
|
@ -48,8 +48,8 @@ class BatchNormMKLDNNHandler
|
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
|
md, epsilon, flags);
|
|
|
|
|
}
|
|
|
|
|
BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon,
|
|
|
|
|
const unsigned &flags,
|
|
|
|
|
BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
|
|
|
|
|
const mkldnn::normalization_flags &flags,
|
|
|
|
|
const MKLDNNMemoryFormat diff_fmt,
|
|
|
|
|
const MKLDNNMemoryFormat src_fmt,
|
|
|
|
|
const platform::MKLDNNDeviceContext &dev_ctx,
|
|
|
|
@ -70,14 +70,13 @@ class BatchNormMKLDNNHandler
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) {
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->weights_primitive_desc(), scaleshift_data,
|
|
|
|
|
"@scaleshift_mem_p");
|
|
|
|
|
this->fwd_pd_->weights_desc(), scaleshift_data, "@scaleshift_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
|
|
|
|
|
T *diff_scaleshift_data) {
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->bwd_pd_->diff_weights_primitive_desc(), diff_scaleshift_data,
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
|
|
|
|
|
diff_scaleshift_data,
|
|
|
|
|
"@diff_scaleshift_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -85,32 +84,30 @@ class BatchNormMKLDNNHandler
|
|
|
|
|
const framework::Tensor *mean) {
|
|
|
|
|
const T *mean_data = mean->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->mean_primitive_desc(), to_void_cast<T>(mean_data),
|
|
|
|
|
"@mean_mem_p");
|
|
|
|
|
this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
|
|
|
|
|
T *mean_data = mean->mutable_data<T>(
|
|
|
|
|
this->place_, this->fwd_pd_->mean_primitive_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->mean_primitive_desc(), mean_data, "@mean_mem_p");
|
|
|
|
|
T *mean_data = mean->mutable_data<T>(this->place_,
|
|
|
|
|
this->fwd_pd_->mean_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
|
|
|
|
|
mean_data, "@mean_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
|
|
|
|
|
const framework::Tensor *variance) {
|
|
|
|
|
const T *variance_data = variance->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->variance_primitive_desc(),
|
|
|
|
|
to_void_cast<T>(variance_data), "@variance_mem_p");
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
|
|
|
|
|
to_void_cast<T>(variance_data),
|
|
|
|
|
"@variance_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
|
|
|
|
|
framework::Tensor *variance) {
|
|
|
|
|
T *variance_data = variance->mutable_data<T>(
|
|
|
|
|
this->place_, this->fwd_pd_->variance_primitive_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->variance_primitive_desc(), variance_data,
|
|
|
|
|
"@variance_mem_p");
|
|
|
|
|
this->place_, this->fwd_pd_->variance_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
|
|
|
|
|
variance_data, "@variance_mem_p");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -140,11 +137,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
|
|
|
|
|
"Wrong layout set for X tensor");
|
|
|
|
|
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef,
|
|
|
|
|
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
"Wrong format set for X tensor");
|
|
|
|
|
|
|
|
|
|
auto src_tz = paddle::framework::vectorize<int>(x->dims());
|
|
|
|
|
auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
|
|
|
|
|
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
|
|
|
|
|
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
|
|
|
|
|
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
|
|
|
|
|
const unsigned int C = scale_tz[0];
|
|
|
|
|
|
|
|
|
@ -156,9 +153,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
shift->data<T>() + C);
|
|
|
|
|
|
|
|
|
|
// Flags are added by bitwise OR operation
|
|
|
|
|
unsigned flags = mkldnn::use_scale_shift; // 001
|
|
|
|
|
if (global_stats) flags |= mkldnn::use_global_stats; // 010
|
|
|
|
|
if (fuse_with_relu && is_test) flags |= mkldnn::fuse_bn_relu; // 100
|
|
|
|
|
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
|
|
|
|
|
if (global_stats)
|
|
|
|
|
flags |= mkldnn::normalization_flags::use_global_stats; // 010
|
|
|
|
|
if (fuse_with_relu && is_test)
|
|
|
|
|
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
|
|
|
|
|
|
|
|
|
|
BatchNormMKLDNNHandler<T> handler(
|
|
|
|
|
src_tz, epsilon, flags, global_stats,
|
|
|
|
@ -170,38 +169,35 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireScaleShiftMemory(scaleshift_data.data());
|
|
|
|
|
auto dst_memory = handler.AcquireDstMemory(y);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::batch_normalization_forward> batch_norm_p;
|
|
|
|
|
auto batch_norm_p = handler.AcquireForwardPrimitive();
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> mean_memory;
|
|
|
|
|
std::shared_ptr<memory> variance_memory;
|
|
|
|
|
|
|
|
|
|
if (global_stats) {
|
|
|
|
|
// mean and variance are taken from input Tensor
|
|
|
|
|
const auto *mean = ctx.Input<Tensor>("Mean");
|
|
|
|
|
const auto *variance = ctx.Input<Tensor>("Variance");
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> mean_memory = handler.AcquireMeanMemory(mean);
|
|
|
|
|
std::shared_ptr<memory> variance_memory =
|
|
|
|
|
handler.AcquireVarianceMemory(variance);
|
|
|
|
|
|
|
|
|
|
batch_norm_p = handler.AcquireForwardPrimitive(
|
|
|
|
|
*src_memory, (const mkldnn::primitive::at &)*mean_memory,
|
|
|
|
|
(const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory,
|
|
|
|
|
*dst_memory);
|
|
|
|
|
mean_memory = handler.AcquireMeanMemory(mean);
|
|
|
|
|
variance_memory = handler.AcquireVarianceMemory(variance);
|
|
|
|
|
} else {
|
|
|
|
|
// mean and variance are calculated and saved in output Tensor
|
|
|
|
|
std::shared_ptr<memory> mean_memory =
|
|
|
|
|
handler.AcquireMeanMemory(batch_mean);
|
|
|
|
|
std::shared_ptr<memory> variance_memory =
|
|
|
|
|
handler.AcquireVarianceMemory(batch_variance);
|
|
|
|
|
|
|
|
|
|
batch_norm_p = handler.AcquireForwardPrimitive(
|
|
|
|
|
*src_memory, *scaleshift_memory, *dst_memory, *mean_memory,
|
|
|
|
|
*variance_memory);
|
|
|
|
|
mean_memory = handler.AcquireMeanMemory(batch_mean);
|
|
|
|
|
variance_memory = handler.AcquireVarianceMemory(batch_variance);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
y->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
|
|
|
|
|
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline;
|
|
|
|
|
pipeline.push_back(*batch_norm_p);
|
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
mkldnn::stream astream(dev_ctx.GetEngine());
|
|
|
|
|
batch_norm_p->execute(astream,
|
|
|
|
|
{{MKLDNN_ARG_SRC, *src_memory},
|
|
|
|
|
{MKLDNN_ARG_SCALE_SHIFT, *scaleshift_memory},
|
|
|
|
|
{MKLDNN_ARG_MEAN, *mean_memory},
|
|
|
|
|
{MKLDNN_ARG_VARIANCE, *variance_memory},
|
|
|
|
|
{MKLDNN_ARG_DST, *dst_memory}});
|
|
|
|
|
astream.wait();
|
|
|
|
|
|
|
|
|
|
if (!global_stats) {
|
|
|
|
|
// mkldnn only compute stats for current batch
|
|
|
|
@ -245,11 +241,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
|
|
|
|
|
"Wrong layout set for Input diff_y tensor");
|
|
|
|
|
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef,
|
|
|
|
|
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
"Wrong format set for Input diff_y tensor");
|
|
|
|
|
|
|
|
|
|
auto src_tz = paddle::framework::vectorize<int>(x->dims());
|
|
|
|
|
auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
|
|
|
|
|
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
|
|
|
|
|
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
|
|
|
|
|
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
|
|
|
|
|
|
|
|
|
|
const unsigned int C = scale_tz[0];
|
|
|
|
@ -261,8 +257,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
|
|
|
|
|
|
|
|
|
|
BatchNormMKLDNNHandler<T> handler(
|
|
|
|
|
src_tz, epsilon, mkldnn::use_scale_shift, dst_format, input_format,
|
|
|
|
|
dev_ctx, ctx.GetPlace(), ctx.InputName("SavedMean"));
|
|
|
|
|
src_tz, epsilon, mkldnn::normalization_flags::use_scale_shift,
|
|
|
|
|
dst_format, input_format, dev_ctx, ctx.GetPlace(),
|
|
|
|
|
ctx.InputName("SavedMean"));
|
|
|
|
|
|
|
|
|
|
// MKLDNN requires a single piece of memory for scale and shift/bias data
|
|
|
|
|
const size_t scaleshift_size = 2 * C;
|
|
|
|
@ -285,13 +282,18 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
|
|
|
|
|
|
|
|
|
|
// finally create batch_norm backward primitive
|
|
|
|
|
auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(
|
|
|
|
|
*src_memory, *mean_memory, *variance_memory, *diff_dst_memory,
|
|
|
|
|
*scaleshift_memory, *diff_src_memory, *diff_scaleshift_memory);
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
pipeline.push_back(*batch_norm_bwd_p);
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
|
|
|
|
|
|
|
|
|
|
mkldnn::stream astream(dev_ctx.GetEngine());
|
|
|
|
|
batch_norm_bwd_p->execute(
|
|
|
|
|
astream, {{MKLDNN_ARG_SRC, *src_memory},
|
|
|
|
|
{MKLDNN_ARG_MEAN, *mean_memory},
|
|
|
|
|
{MKLDNN_ARG_VARIANCE, *variance_memory},
|
|
|
|
|
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
|
|
|
|
|
{MKLDNN_ARG_SCALE_SHIFT, *scaleshift_memory},
|
|
|
|
|
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
|
|
|
|
|
{MKLDNN_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}});
|
|
|
|
|
astream.wait();
|
|
|
|
|
|
|
|
|
|
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|