|
|
|
@ -62,56 +62,42 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<batch_norm_fwd> AcquireTestBatchNormFwd(
|
|
|
|
|
std::shared_ptr<batch_norm_fwd> AcquireTestTrainingBatchNormFwd(
|
|
|
|
|
std::shared_ptr<memory> src_memory,
|
|
|
|
|
const mkldnn::primitive::at &mean_memory,
|
|
|
|
|
const mkldnn::primitive::at &variance_memory,
|
|
|
|
|
std::shared_ptr<memory> scaleshift_memory,
|
|
|
|
|
std::shared_ptr<memory> dst_memory) {
|
|
|
|
|
std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
|
|
|
|
|
std::shared_ptr<memory> variance_memory, bool is_test) {
|
|
|
|
|
auto prim_key = key_ + "@batch_norm_p";
|
|
|
|
|
auto batch_norm_p =
|
|
|
|
|
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
(batch_norm_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find batch norm primitive for test in device context");
|
|
|
|
|
if (batch_norm_p == nullptr) {
|
|
|
|
|
batch_norm_p = std::make_shared<batch_norm_fwd>(
|
|
|
|
|
*batch_norm_pd_, *src_memory, mean_memory, variance_memory,
|
|
|
|
|
*scaleshift_memory, *dst_memory);
|
|
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, batch_norm_p);
|
|
|
|
|
} else {
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
return batch_norm_p;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_,
|
|
|
|
|
"Fail to find batch norm primitive in device context");
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<batch_norm_fwd> AcquireTrainingBatchNormFwd(
|
|
|
|
|
std::shared_ptr<memory> src_memory,
|
|
|
|
|
std::shared_ptr<memory> scaleshift_memory,
|
|
|
|
|
std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
|
|
|
|
|
std::shared_ptr<memory> variance_memory) {
|
|
|
|
|
auto prim_key = key_ + "@batch_norm_p";
|
|
|
|
|
auto batch_norm_p =
|
|
|
|
|
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
(batch_norm_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find batch norm primitive for training in device context");
|
|
|
|
|
if (batch_norm_p == nullptr) {
|
|
|
|
|
batch_norm_p = std::make_shared<batch_norm_fwd>(
|
|
|
|
|
*batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
|
|
|
|
|
*mean_memory, *variance_memory);
|
|
|
|
|
if (is_test) {
|
|
|
|
|
batch_norm_p = std::make_shared<batch_norm_fwd>(
|
|
|
|
|
*batch_norm_pd_, *src_memory,
|
|
|
|
|
(const mkldnn::primitive::at &)*mean_memory,
|
|
|
|
|
(const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory,
|
|
|
|
|
*dst_memory);
|
|
|
|
|
} else {
|
|
|
|
|
batch_norm_p = std::make_shared<batch_norm_fwd>(
|
|
|
|
|
*batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
|
|
|
|
|
*mean_memory, *variance_memory);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, batch_norm_p);
|
|
|
|
|
} else {
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return batch_norm_p;
|
|
|
|
|
}
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(const memory::dims &input_dims, float epsilon,
|
|
|
|
|
unsigned flag, bool is_test, memory::format format,
|
|
|
|
|
const std::string &suffix) {
|
|
|
|
|
const std::string &suffix = "") {
|
|
|
|
|
auto dims2str = [](const memory::dims &operand_dims) {
|
|
|
|
|
std::string dstr = "";
|
|
|
|
|
for (size_t i = 0; i < operand_dims.size(); ++i) {
|
|
|
|
@ -128,19 +114,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Builds the textual fingerprint of a batch-norm configuration. Callers
// append the result to a base key to form keys for primitives reuse.
//
// Layout of the returned string, with no separators other than the dashes
// that follow every dimension:
//   "<d0>-<d1>-...-<dn>-" + epsilon + flag + is_test + format
//
// input_dims  dimensions of the input; each one is emitted followed by '-'.
// epsilon     rendered with std::to_string. NOTE(review): std::to_string
//             prints floating-point values with six fractional digits, so
//             epsilons that differ only beyond that precision yield the
//             same text.
// flag        rendered as an unsigned decimal.
// is_test     rendered as "0" or "1" (bool promotes to int).
// format      rendered as the integer value of the enumerator.
std::string gethash(const memory::dims &input_dims, float epsilon,
                    unsigned flag, bool is_test, memory::format format) {
  std::string fingerprint;

  // Dimensions first, each terminated by a dash.
  for (const auto &extent : input_dims) {
    fingerprint += std::to_string(extent);
    fingerprint += "-";
  }

  // Scalar attributes follow in a fixed order.
  fingerprint += std::to_string(epsilon);
  fingerprint += std::to_string(flag);
  fingerprint += std::to_string(is_test);
  fingerprint += std::to_string(format);

  return fingerprint;
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> UpdateMemoryData(
|
|
|
|
|
const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key,
|
|
|
|
|
void *new_ptr) {
|
|
|
|
@ -274,10 +247,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireVarianceMemoryFromPrimitive(
|
|
|
|
|
to_void_cast(variance_data));
|
|
|
|
|
|
|
|
|
|
batch_norm_p = handler.AcquireTestBatchNormFwd(
|
|
|
|
|
src_memory, (const mkldnn::primitive::at &)*mean_memory,
|
|
|
|
|
(const mkldnn::primitive::at &)*variance_memory, scaleshift_memory,
|
|
|
|
|
dst_memory);
|
|
|
|
|
batch_norm_p = handler.AcquireTestTrainingBatchNormFwd(
|
|
|
|
|
src_memory, scaleshift_memory, dst_memory, mean_memory,
|
|
|
|
|
variance_memory, true);
|
|
|
|
|
} else {
|
|
|
|
|
// create mkldnn memory for stats (as output)
|
|
|
|
|
std::shared_ptr<memory> mean_memory =
|
|
|
|
@ -285,9 +257,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::shared_ptr<memory> variance_memory =
|
|
|
|
|
handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data);
|
|
|
|
|
|
|
|
|
|
batch_norm_p = handler.AcquireTrainingBatchNormFwd(
|
|
|
|
|
batch_norm_p = handler.AcquireTestTrainingBatchNormFwd(
|
|
|
|
|
src_memory, scaleshift_memory, dst_memory, mean_memory,
|
|
|
|
|
variance_memory);
|
|
|
|
|
variance_memory, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
y->set_layout(DataLayout::kMKLDNN);
|
|
|
|
@ -377,7 +349,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// keys for primitives reuse
|
|
|
|
|
const std::string key_with_hash =
|
|
|
|
|
key + gethash(src_tz, epsilon, flags, false, input_format);
|
|
|
|
|
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
|
|
|
|
|
input_format);
|
|
|
|
|
const std::string key_batch_norm_bwd_p =
|
|
|
|
|
key_with_hash + "@batch_norm_bwd_p";
|
|
|
|
|
const std::string key_batch_norm_src_mem_p =
|
|
|
|
|