|
|
|
@ -21,8 +21,6 @@ namespace paddle {
|
|
|
|
|
|
|
|
|
|
REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
|
|
|
|
|
|
|
|
|
|
const real MKLDNNBatchNormLayer::EPS = 1E-5;
|
|
|
|
|
|
|
|
|
|
bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
|
|
|
|
|
const ParameterMap& parameterMap) {
|
|
|
|
|
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
|
|
|
|
@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
|
|
|
|
|
useGlobalStats_ = config_.use_global_stats();
|
|
|
|
|
}
|
|
|
|
|
movingAvgFraction_ = config_.moving_average_fraction();
|
|
|
|
|
epsilon_ = config_.epsilon();
|
|
|
|
|
|
|
|
|
|
VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
|
|
|
|
|
<< " --- global stats";
|
|
|
|
|
VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
|
|
|
|
@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD(
|
|
|
|
|
if (wgt) {
|
|
|
|
|
flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
|
|
|
|
|
}
|
|
|
|
|
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
|
|
|
|
|
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), epsilon_, flags_);
|
|
|
|
|
pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
|
|
|
|
|
if (wgt) {
|
|
|
|
@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD(
|
|
|
|
|
}
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc());
|
|
|
|
|
auto md = in->getMemoryDesc();
|
|
|
|
|
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
|
|
|
|
|
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, epsilon_, flags_);
|
|
|
|
|
pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
|
|
|
|
|
CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc());
|
|
|
|
|