|
|
|
@ -25,11 +25,18 @@ namespace paddle {
|
|
|
|
|
|
|
|
|
|
bool MkldnnLayer::init(const LayerMap& layerMap,
|
|
|
|
|
const ParameterMap& parameterMap) {
|
|
|
|
|
if (!Layer::init(layerMap, parameterMap)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
|
|
|
|
|
<< "Please set WITH_MKLDNN=ON "
|
|
|
|
|
<< "and set use_mkldnn=True";
|
|
|
|
|
stream_.reset(new MkldnnStream());
|
|
|
|
|
engine_ = CpuEngine::Instance().getEngine();
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): deivecId
|
|
|
|
|
return Layer::init(layerMap, parameterMap);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MkldnnLayer::resetForwardFC(int bs,
|
|
|
|
@ -42,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs,
|
|
|
|
|
real* wgtData,
|
|
|
|
|
real* biasData) {
|
|
|
|
|
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
|
|
|
|
|
engine_ = CpuEngine::Instance().getEngine();
|
|
|
|
|
|
|
|
|
|
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
|
|
|
|
|
: createMD({bs, ic}, format::nc);
|
|
|
|
@ -52,21 +58,21 @@ void MkldnnLayer::resetForwardFC(int bs,
|
|
|
|
|
: createMD({}, format::format_undef);
|
|
|
|
|
mem::desc topMD = createMD({bs, oc}, format::nc);
|
|
|
|
|
|
|
|
|
|
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
|
|
|
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
|
|
|
outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
|
|
|
|
|
|
|
|
|
|
mkldnn::prop_kind pk = mkldnn::prop_kind::forward;
|
|
|
|
|
fc_fwd::desc fwdDesc = biasData != NULL
|
|
|
|
|
? fc_fwd::desc(pk, botMD, wgtMD, biasMD, topMD)
|
|
|
|
|
: fc_fwd::desc(pk, botMD, wgtMD, topMD);
|
|
|
|
|
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
|
|
|
|
|
|
|
|
|
|
mem bot = mem(mem::primitive_desc(botMD, engine_), botData);
|
|
|
|
|
mem wgt = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
|
|
|
|
|
mem top = mem(mem::primitive_desc(topMD, engine_), topData);
|
|
|
|
|
|
|
|
|
|
if (biasData != NULL) {
|
|
|
|
|
mem bias = mem(mem::primitive_desc(biasMD, engine_), biasData);
|
|
|
|
|
fwd_.reset(new fc_fwd(fwdPD, bot, wgt, bias, top));
|
|
|
|
|
biasVal_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasData));
|
|
|
|
|
fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
|
|
|
|
|
} else {
|
|
|
|
|
fwd_.reset(new fc_fwd(fwdPD, bot, wgt, top));
|
|
|
|
|
fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
|
|
|
|
|
}
|
|
|
|
|
pipelineFwd_.clear();
|
|
|
|
|
pipelineFwd_.push_back(*fwd_);
|
|
|
|
@ -84,8 +90,12 @@ void MkldnnLayer::mkldnnForwardFC(int bs,
|
|
|
|
|
// if input size changed, reset it
|
|
|
|
|
resetForwardFC(bs, ic, ih, iw, botData, oc, topData, wgtData, biasData);
|
|
|
|
|
|
|
|
|
|
this->cvtWgtFromPaddle();
|
|
|
|
|
|
|
|
|
|
// update input, since the data might be changed if this is after data layer
|
|
|
|
|
inVal_->set_data_handle(botData);
|
|
|
|
|
|
|
|
|
|
// just forward
|
|
|
|
|
// update botdata
|
|
|
|
|
stream_->submit(pipelineFwd_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -112,6 +122,10 @@ void MkldnnLayer::resetBackwardFC(int bs,
|
|
|
|
|
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
|
|
|
|
|
: createMD({}, format::format_undef);
|
|
|
|
|
|
|
|
|
|
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
|
|
|
|
|
wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
|
|
|
|
|
outGrad_.reset(new mem(mem::primitive_desc(topMD, engine_), topDiff));
|
|
|
|
|
|
|
|
|
|
fc_fwd::desc fwdDesc =
|
|
|
|
|
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
|
|
|
|
|
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
|
|
|
|
@ -121,15 +135,12 @@ void MkldnnLayer::resetBackwardFC(int bs,
|
|
|
|
|
fc_bwdWgt::primitive_desc bwdWgtPD =
|
|
|
|
|
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
|
|
|
|
|
|
|
|
|
|
mem botVal = mem(mem::primitive_desc(botMD, engine_), botData);
|
|
|
|
|
mem wgtGrad = mem(mem::primitive_desc(wgtMD, engine_), wgtDiff);
|
|
|
|
|
mem topGrad = mem(mem::primitive_desc(topMD, engine_), topDiff);
|
|
|
|
|
|
|
|
|
|
if (biasDiff != NULL) {
|
|
|
|
|
mem biasGrad = mem(mem::primitive_desc(biasMD, engine_), biasDiff);
|
|
|
|
|
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad, biasGrad));
|
|
|
|
|
biasGrad_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasDiff));
|
|
|
|
|
bwdWgt_.reset(
|
|
|
|
|
new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
|
|
|
|
|
} else {
|
|
|
|
|
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad));
|
|
|
|
|
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_));
|
|
|
|
|
}
|
|
|
|
|
pipelineBwd_.clear();
|
|
|
|
|
pipelineBwd_.push_back(*bwdWgt_);
|
|
|
|
@ -142,9 +153,9 @@ void MkldnnLayer::resetBackwardFC(int bs,
|
|
|
|
|
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
|
|
|
|
|
fc_bwdData::primitive_desc bwdDataPD =
|
|
|
|
|
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
|
|
|
|
|
mem botGrad = mem(mem::primitive_desc(botMD, engine_), botDiff);
|
|
|
|
|
mem wgtVal = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
|
|
|
|
|
bwdData_.reset(new fc_bwdData(bwdDataPD, topGrad, wgtVal, botGrad));
|
|
|
|
|
inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
|
|
|
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
|
|
|
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
|
|
|
|
|
pipelineBwd_.push_back(*bwdData_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -172,11 +183,18 @@ void MkldnnLayer::mkldnnBackwardFC(int bs,
|
|
|
|
|
wgtData,
|
|
|
|
|
biasDiff);
|
|
|
|
|
|
|
|
|
|
// just forward
|
|
|
|
|
// update botdata
|
|
|
|
|
// update data
|
|
|
|
|
outGrad_->set_data_handle(topDiff);
|
|
|
|
|
|
|
|
|
|
stream_->submit(pipelineBwd_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MkldnnLayer::printSizeInfo() {
|
|
|
|
|
VLOG(DNN_SIZES) << "bs: " << bs_ << ", ic: " << ic_ << ", ih: " << ih_
|
|
|
|
|
<< ", iw: " << iw_ << ", oc: " << oc_ << ", oh: " << oh_
|
|
|
|
|
<< ", ow: " << ow_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mem::desc MkldnnLayer::createMD(mem::dims dims,
|
|
|
|
|
mem::format fmt,
|
|
|
|
|
mem::data_type type) {
|
|
|
|
|