skip reset mkldnn when input size does not change

revert-3824-remove_grad_op_type
tensor-tang 8 years ago
parent 6373291c77
commit e18fbd8208

@ -49,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs,
real* wgtData,
real* biasData) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
: createMD({bs, ic}, format::nc);
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
@ -58,7 +57,12 @@ void MkldnnLayer::resetForwardFC(int bs,
: createMD({}, format::format_undef);
mem::desc topMD = createMD({bs, oc}, format::nc);
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
if (inVal_ && inVal_->get_primitive_desc() == botPD) {
return;
}
inVal_.reset(new mem(botPD, botData));
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
@ -111,7 +115,6 @@ void MkldnnLayer::resetBackwardFC(int bs,
real* wgtData,
real* biasDiff) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
engine_ = CpuEngine::Instance().getEngine();
// backward weight
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
@ -122,9 +125,19 @@ void MkldnnLayer::resetBackwardFC(int bs,
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
: createMD({}, format::format_undef);
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_);
if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
return;
}
if (inVal_) {
// update data
inVal_->set_data_handle(botData);
} else {
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
}
wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
outGrad_.reset(new mem(mem::primitive_desc(topMD, engine_), topDiff));
outGrad_.reset(new mem(topPD, topDiff));
fc_fwd::desc fwdDesc =
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
@ -154,7 +167,12 @@ void MkldnnLayer::resetBackwardFC(int bs,
fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
if (wgtVal_) {
// update data
wgtVal_->set_data_handle(wgtData);
} else {
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
}
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
pipelineBwd_.push_back(*bwdData_);
}

Loading…
Cancel
Save