|
|
|
@ -49,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs,
|
|
|
|
|
real* wgtData,
|
|
|
|
|
real* biasData) {
|
|
|
|
|
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
|
|
|
|
|
|
|
|
|
|
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
|
|
|
|
|
: createMD({bs, ic}, format::nc);
|
|
|
|
|
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
|
|
|
|
@ -58,7 +57,12 @@ void MkldnnLayer::resetForwardFC(int bs,
|
|
|
|
|
: createMD({}, format::format_undef);
|
|
|
|
|
mem::desc topMD = createMD({bs, oc}, format::nc);
|
|
|
|
|
|
|
|
|
|
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
|
|
|
|
|
mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
|
|
|
|
|
if (inVal_ && inVal_->get_primitive_desc() == botPD) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inVal_.reset(new mem(botPD, botData));
|
|
|
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
|
|
|
outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
|
|
|
|
|
|
|
|
|
@ -111,7 +115,6 @@ void MkldnnLayer::resetBackwardFC(int bs,
|
|
|
|
|
real* wgtData,
|
|
|
|
|
real* biasDiff) {
|
|
|
|
|
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
|
|
|
|
|
engine_ = CpuEngine::Instance().getEngine();
|
|
|
|
|
|
|
|
|
|
// backward weight
|
|
|
|
|
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
|
|
|
|
@ -122,9 +125,19 @@ void MkldnnLayer::resetBackwardFC(int bs,
|
|
|
|
|
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
|
|
|
|
|
: createMD({}, format::format_undef);
|
|
|
|
|
|
|
|
|
|
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
|
|
|
|
|
mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_);
|
|
|
|
|
if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (inVal_) {
|
|
|
|
|
// update data
|
|
|
|
|
inVal_->set_data_handle(botData);
|
|
|
|
|
} else {
|
|
|
|
|
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
|
|
|
|
|
}
|
|
|
|
|
wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
|
|
|
|
|
outGrad_.reset(new mem(mem::primitive_desc(topMD, engine_), topDiff));
|
|
|
|
|
outGrad_.reset(new mem(topPD, topDiff));
|
|
|
|
|
|
|
|
|
|
fc_fwd::desc fwdDesc =
|
|
|
|
|
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
|
|
|
|
@ -154,7 +167,12 @@ void MkldnnLayer::resetBackwardFC(int bs,
|
|
|
|
|
fc_bwdData::primitive_desc bwdDataPD =
|
|
|
|
|
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
|
|
|
|
|
inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
|
|
|
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
|
|
|
if (wgtVal_) {
|
|
|
|
|
// update data
|
|
|
|
|
wgtVal_->set_data_handle(wgtData);
|
|
|
|
|
} else {
|
|
|
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
|
|
|
}
|
|
|
|
|
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
|
|
|
|
|
pipelineBwd_.push_back(*bwdData_);
|
|
|
|
|
}
|
|
|
|
|