@@ -31,22 +31,45 @@ class BatchNormMKLDNNHandler
     : public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
                                       mkldnn::batch_normalization_backward> {
  public:
-  BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
-                         const mkldnn::normalization_flags &flags,
-                         const bool &global_stats, const MKLDNNMemoryFormat fmt,
+  BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                          const platform::MKLDNNDeviceContext &dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string &uniq_name)
+                         const mkldnn::engine mkldnn_engine,
+                         platform::Place cpu_place, const Tensor *x,
+                         const bool global_stats, const bool test_mode,
+                         const std::string &unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
                                  mkldnn::batch_normalization_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
-    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireForwardPrimitiveDescriptor(
-        global_stats == true ? mkldnn::prop_kind::forward_scoring
-                             : mkldnn::prop_kind::forward_training,
-        md, epsilon, flags);
+            platform::CreateKey(framework::vectorize(x->dims()), unique_name)) {
+    if (!this->isCached()) {
+      const float epsilon = ctx.Attr<float>("epsilon");
+      const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
+
+      PADDLE_ENFORCE_EQ(
+          x->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for X tensor"));
+      PADDLE_ENFORCE_NE(
+          x->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for X tensor"));
+
+      auto src_tz = paddle::framework::vectorize(x->dims());
+
+      // Flags are added by bitwise OR operation
+      auto flags = mkldnn::normalization_flags::use_scale_shift;  // 001
+      if (global_stats)
+        flags |= mkldnn::normalization_flags::use_global_stats;  // 010
+      if (fuse_with_relu && test_mode)
+        flags |= mkldnn::normalization_flags::fuse_norm_relu;  // 100
+
+      auto md = mkldnn::memory::desc(
+          src_tz, platform::MKLDNNGetDataType<T>(),
+          platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
+
+      this->AcquireForwardPrimitiveDescriptor(
+          global_stats == true ? mkldnn::prop_kind::forward_scoring
+                               : mkldnn::prop_kind::forward_training,
+          md, epsilon, flags);
+    }
   }
   BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
                          const mkldnn::normalization_flags &flags,
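The new constructor defers all descriptor setup until `this->isCached()` reports a miss, so the `memory::desc` and forward primitive descriptor are built once per `platform::CreateKey(...)` key and reused on later iterations. Below is a minimal sketch of that key-based caching idea, assuming a plain `std::unordered_map` cache and hypothetical names; it is not Paddle's actual `MKLDNNHandlerT`, whose cache lives in the `MKLDNNDeviceContext` and is keyed by the input dims plus `unique_name`.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for an expensive-to-build oneDNN primitive descriptor.
struct PrimitiveDesc {
  explicit PrimitiveDesc(float eps) : epsilon(eps) {}
  float epsilon;
};

class CachingHandler {
 public:
  explicit CachingHandler(const std::string &key) : key_(key) {}

  // Mirrors the isCached() check: true once some handler built the entry.
  bool isCached() const { return cache_.count(key_) != 0; }

  // Expensive setup runs only on a cache miss.
  void AcquireForwardPrimitiveDescriptor(float epsilon) {
    if (!isCached()) {
      cache_.emplace(key_, std::make_shared<PrimitiveDesc>(epsilon));
    }
  }

 private:
  std::string key_;
  static std::unordered_map<std::string, std::shared_ptr<PrimitiveDesc>> cache_;
};

std::unordered_map<std::string, std::shared_ptr<PrimitiveDesc>>
    CachingHandler::cache_;

int main() {
  CachingHandler first("2x3x4x5@SavedMean");
  first.AcquireForwardPrimitiveDescriptor(1e-5f);

  CachingHandler second("2x3x4x5@SavedMean");  // same key -> cache hit
  std::cout << std::boolalpha << second.isCached() << '\n';  // prints: true
}
```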
@@ -68,9 +91,30 @@ class BatchNormMKLDNNHandler
         mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
   }

-  std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) {
-    return this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(), scaleshift_data, "@scaleshift_mem_p");
+  std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
+                                                          const Tensor *shift,
+                                                          const bool is_test) {
+    auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
+    if (scaleshift_memory == nullptr || !is_test) {
+      auto scale_tz = paddle::framework::vectorize(scale->dims());
+      const unsigned int C = scale_tz[0];
+      PADDLE_ENFORCE_EQ(
+          scale_tz.size(), 1,
+          platform::errors::InvalidArgument(
+              "Dims of scale tensor must be 1, but received scale's size is %d",
+              scale_tz.size()));
+
+      auto mem_p = this->AcquireMemoryFromPrimitive(
+          this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
+
+      // MKLDNN requires a single piece of memory for scale and shift/bias data
+      auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle());
+      std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
+      std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
+
+      return mem_p;
+    }
+    return scaleshift_memory;
   }

   std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
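`AcquireScaleShiftMemory` now packs the two separate `Scale` and `Bias` tensors into the single weights buffer MKL-DNN expects: C scale values followed by C shift values. A small self-contained sketch of that packing, with made-up sizes and values:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const unsigned int C = 4;  // channel count, made up for the example
  std::vector<float> scale{1.0f, 2.0f, 3.0f, 4.0f};  // "Scale" tensor
  std::vector<float> shift{0.1f, 0.2f, 0.3f, 0.4f};  // "Bias" tensor

  // One contiguous block: scale occupies [0, C), shift occupies [C, 2C).
  std::vector<float> scaleshift(2 * C);
  std::copy(scale.begin(), scale.end(), scaleshift.begin());
  std::copy(shift.begin(), shift.end(), scaleshift.begin() + C);

  for (float v : scaleshift) std::cout << v << ' ';
  std::cout << '\n';  // 1 2 3 4 0.1 0.2 0.3 0.4
}
```

Note the role of `is_test` in the handler above: a cached buffer is returned untouched during inference, but repacked on every step during training, when scale and shift may have changed.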
@@ -115,64 +159,30 @@ template <typename T>
 class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const float momentum = ctx.Attr<float>("momentum");
+    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto &mkldnn_engine = dev_ctx.GetEngine();
+
     const bool is_test = ctx.Attr<bool>("is_test");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
-    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
     const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
-    bool test_mode = is_test && (!trainable_stats);
-
-    bool global_stats = test_mode || use_global_stats;
-
-    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const bool test_mode = is_test && (!trainable_stats);
+    const bool global_stats = test_mode || use_global_stats;

     const auto *x = ctx.Input<Tensor>("X");
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *shift = ctx.Input<Tensor>("Bias");

     auto *y = ctx.Output<Tensor>("Y");
-    auto *mean_out = ctx.Output<Tensor>("MeanOut");
-    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
     auto *batch_mean = ctx.Output<Tensor>("SavedMean");
     auto *batch_variance = ctx.Output<Tensor>("SavedVariance");

-    PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
-                      "Wrong layout set for X tensor");
-    PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
-                      "Wrong format set for X tensor");
-
-    auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
-    auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
-    PADDLE_ENFORCE_EQ(
-        scale_tz.size(), 1,
-        platform::errors::InvalidArgument(
-            "Dims of scale tensor must be 1, but received scale's size is %d",
-            scale_tz.size()));
-    const unsigned int C = scale_tz[0];
-
-    // MKLDNN requires a single piece of memory for scale and shift/bias data
-
-    std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
-    scaleshift_data.reserve(2 * C);
-    scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
-                           shift->data<T>() + C);
-
-    // Flags are added by bitwise OR operation
-    auto flags = mkldnn::normalization_flags::use_scale_shift;  // 001
-    if (global_stats)
-      flags |= mkldnn::normalization_flags::use_global_stats;  // 010
-    if (fuse_with_relu && test_mode)
-      flags |= mkldnn::normalization_flags::fuse_norm_relu;  // 100
-
-    BatchNormMKLDNNHandler<T> handler(
-        src_tz, epsilon, flags, global_stats,
-        platform::MKLDNNFormatForSize(src_tz.size(), x->format()), dev_ctx,
-        ctx.GetPlace(), ctx.OutputName("SavedMean"));
+    BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
+                                      ctx.GetPlace(), x, global_stats,
+                                      test_mode, ctx.OutputName("SavedMean"));

     auto src_memory = handler.AcquireSrcMemory(x);
     auto scaleshift_memory =
-        handler.AcquireScaleShiftMemory(scaleshift_data.data());
+        handler.AcquireScaleShiftMemory(scale, shift, is_test);
     auto dst_memory = handler.AcquireDstMemory(y);

     auto batch_norm_p = handler.AcquireForwardPrimitive();
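The flag-composition code deleted from the kernel here (and re-added inside the handler in the first hunk) relies on `mkldnn::normalization_flags` being an OR-composable `enum class`. A sketch of that pattern under illustrative names; oneDNN ships its own operator overloads, so the ones below are stand-ins:

```cpp
#include <cstdint>
#include <iostream>

// Illustrative OR-composable flag enum, mirroring the bit comments
// (001/010/100) in the diff above.
enum class norm_flags : std::uint32_t {
  use_scale_shift = 0x1,   // 001
  use_global_stats = 0x2,  // 010
  fuse_norm_relu = 0x4,    // 100
};

constexpr norm_flags operator|(norm_flags a, norm_flags b) {
  return static_cast<norm_flags>(static_cast<std::uint32_t>(a) |
                                 static_cast<std::uint32_t>(b));
}

constexpr bool has_flag(norm_flags value, norm_flags flag) {
  return (static_cast<std::uint32_t>(value) &
          static_cast<std::uint32_t>(flag)) != 0;
}

int main() {
  const bool global_stats = true, fuse_with_relu = true, test_mode = true;

  auto flags = norm_flags::use_scale_shift;
  if (global_stats) flags = flags | norm_flags::use_global_stats;
  if (fuse_with_relu && test_mode) flags = flags | norm_flags::fuse_norm_relu;

  std::cout << has_flag(flags, norm_flags::fuse_norm_relu) << '\n';  // 1
}
```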
@@ -206,6 +216,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     astream.wait();

     if (!global_stats) {
+      auto *mean_out = ctx.Output<Tensor>("MeanOut");
+      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
+      const float momentum = ctx.Attr<float>("momentum");
+
+      const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
+
       // mkldnn only compute stats for current batch
       // so we need compute momentum stats via Eigen lib
       EigenVectorArrayMap<T> batch_mean_e(
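This hunk pulls `MeanOut`, `VarianceOut`, and `momentum` only when fresh batch statistics were actually computed. Since MKL-DNN produces per-batch mean/variance only, the running statistics are blended in with momentum. A sketch of that update using a plain loop instead of the kernel's `EigenVectorArrayMap`, assuming Paddle's usual convention `running = running * momentum + batch * (1 - momentum)`:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const float momentum = 0.9f;
  // Running statistics carried across iterations (MeanOut in the kernel).
  std::vector<float> running_mean{0.5f, 0.5f};
  // Fresh per-batch statistics from MKL-DNN (SavedMean in the kernel).
  std::vector<float> batch_mean{1.0f, 2.0f};

  for (std::size_t c = 0; c < running_mean.size(); ++c) {
    running_mean[c] =
        running_mean[c] * momentum + batch_mean[c] * (1.0f - momentum);
  }

  std::cout << running_mean[0] << ' ' << running_mean[1] << '\n';  // 0.55 0.65
}
```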
@@ -273,11 +289,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {

-    // MKLDNN requires a single piece of memory for scale and shift/bias data
     const size_t scaleshift_size = 2 * C;
-    std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
-    scaleshift_data.reserve(scaleshift_size);
-    scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
-                           shift->data<T>() + C);

     std::vector<T> diff_scaleshift_data;
     diff_scaleshift_data.reserve(scaleshift_size);

|
|
|
|
|
auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
|
|
|
|
|
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
|
|
|
|
|
auto scaleshift_memory =
|
|
|
|
|
handler.AcquireScaleShiftMemory(scaleshift_data.data());
|
|
|
|
|
handler.AcquireScaleShiftMemory(scale, shift, false);
|
|
|
|
|
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
|
|
|
|
|
auto diff_scaleshift_memory =
|
|
|
|
|
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
|
|
|
|
|