You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
223 lines
8.0 KiB
223 lines
8.0 KiB
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#include "MkldnnLayer.h"
|
|
|
|
using mem = mkldnn::memory; // NOLINT
|
|
typedef mem::format format;
|
|
typedef mkldnn::inner_product_forward fc_fwd;
|
|
typedef mkldnn::inner_product_backward_weights fc_bwdWgt;
|
|
typedef mkldnn::inner_product_backward_data fc_bwdData;
|
|
|
|
namespace paddle {
|
|
|
|
bool MkldnnLayer::init(const LayerMap& layerMap,
|
|
const ParameterMap& parameterMap) {
|
|
if (!Layer::init(layerMap, parameterMap)) {
|
|
return false;
|
|
}
|
|
|
|
CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
|
|
<< "Please set WITH_MKLDNN=ON "
|
|
<< "and set use_mkldnn=True";
|
|
stream_.reset(new MkldnnStream());
|
|
engine_ = CpuEngine::Instance().getEngine();
|
|
|
|
// TODO(TJ): deivecId
|
|
return true;
|
|
}
|
|
|
|
void MkldnnLayer::resetForwardFC(int bs,
|
|
int ic,
|
|
int ih,
|
|
int iw,
|
|
real* botData,
|
|
int oc,
|
|
real* topData,
|
|
real* wgtData,
|
|
real* biasData) {
|
|
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
|
|
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
|
|
: createMD({bs, ic}, format::nc);
|
|
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
|
|
: createMD({oc, ic}, format::oi);
|
|
mem::desc biasMD = biasData != NULL ? createMD({oc}, format::x)
|
|
: createMD({}, format::format_undef);
|
|
mem::desc topMD = createMD({bs, oc}, format::nc);
|
|
|
|
mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
|
|
if (inVal_ && inVal_->get_primitive_desc() == botPD) {
|
|
return;
|
|
}
|
|
|
|
inVal_.reset(new mem(botPD, botData));
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
|
|
|
|
mkldnn::prop_kind pk = mkldnn::prop_kind::forward;
|
|
fc_fwd::desc fwdDesc = biasData != NULL
|
|
? fc_fwd::desc(pk, botMD, wgtMD, biasMD, topMD)
|
|
: fc_fwd::desc(pk, botMD, wgtMD, topMD);
|
|
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
|
|
|
|
if (biasData != NULL) {
|
|
biasVal_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasData));
|
|
fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
|
|
} else {
|
|
fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
|
|
}
|
|
pipelineFwd_.clear();
|
|
pipelineFwd_.push_back(*fwd_);
|
|
}
|
|
|
|
void MkldnnLayer::mkldnnForwardFC(int bs,
|
|
int ic,
|
|
int ih,
|
|
int iw,
|
|
real* botData,
|
|
int oc,
|
|
real* topData,
|
|
real* wgtData,
|
|
real* biasData) {
|
|
// if input size changed, reset it
|
|
resetForwardFC(bs, ic, ih, iw, botData, oc, topData, wgtData, biasData);
|
|
|
|
this->convertWeightsFromPaddle();
|
|
|
|
// update input, since the data might be changed if this is after data layer
|
|
inVal_->set_data_handle(botData);
|
|
|
|
// just forward
|
|
stream_->submit(pipelineFwd_);
|
|
}
|
|
|
|
void MkldnnLayer::resetBackwardFC(int bs,
|
|
int ic,
|
|
int ih,
|
|
int iw,
|
|
real* botDiff,
|
|
real* botData,
|
|
int oc,
|
|
real* topDiff,
|
|
real* wgtDiff,
|
|
real* wgtData,
|
|
real* biasDiff) {
|
|
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
|
|
|
|
// backward weight
|
|
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
|
|
: createMD({bs, ic}, format::nc);
|
|
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
|
|
: createMD({oc, ic}, format::oi);
|
|
mem::desc topMD = createMD({bs, oc}, format::nc);
|
|
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
|
|
: createMD({}, format::format_undef);
|
|
|
|
mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_);
|
|
if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
|
|
return;
|
|
}
|
|
|
|
if (inVal_) {
|
|
// update data
|
|
inVal_->set_data_handle(botData);
|
|
} else {
|
|
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
|
|
}
|
|
wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
|
|
outGrad_.reset(new mem(topPD, topDiff));
|
|
|
|
fc_fwd::desc fwdDesc =
|
|
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
|
|
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
|
|
fc_bwdWgt::desc bwdWgtDesc =
|
|
biasDiff != NULL ? fc_bwdWgt::desc(botMD, wgtMD, biasMD, topMD)
|
|
: fc_bwdWgt::desc(botMD, wgtMD, topMD);
|
|
fc_bwdWgt::primitive_desc bwdWgtPD =
|
|
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
|
|
|
|
if (biasDiff != NULL) {
|
|
biasGrad_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasDiff));
|
|
bwdWgt_.reset(
|
|
new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
|
|
} else {
|
|
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_));
|
|
}
|
|
pipelineBwd_.clear();
|
|
pipelineBwd_.push_back(*bwdWgt_);
|
|
|
|
// backward data
|
|
if (botDiff == NULL) {
|
|
return;
|
|
}
|
|
|
|
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
|
|
fc_bwdData::primitive_desc bwdDataPD =
|
|
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
|
|
inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
|
|
if (wgtVal_) {
|
|
// update data
|
|
wgtVal_->set_data_handle(wgtData);
|
|
} else {
|
|
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
|
|
}
|
|
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
|
|
pipelineBwd_.push_back(*bwdData_);
|
|
}
|
|
|
|
void MkldnnLayer::mkldnnBackwardFC(int bs,
|
|
int ic,
|
|
int ih,
|
|
int iw,
|
|
real* botDiff,
|
|
real* botData,
|
|
int oc,
|
|
real* topDiff,
|
|
real* wgtDiff,
|
|
real* wgtData,
|
|
real* biasDiff) {
|
|
// if input size changed, reset it
|
|
resetBackwardFC(bs,
|
|
ic,
|
|
ih,
|
|
iw,
|
|
botDiff,
|
|
botData,
|
|
oc,
|
|
topDiff,
|
|
wgtDiff,
|
|
wgtData,
|
|
biasDiff);
|
|
|
|
// update data
|
|
outGrad_->set_data_handle(topDiff);
|
|
|
|
stream_->submit(pipelineBwd_);
|
|
}
|
|
|
|
// Logs the layer name followed by the size members (bs_, ic_, ih_, iw_, oc_,
// oh_, ow_) at the DNN_SIZES verbosity level.
void MkldnnLayer::printSizeInfo() {
  VLOG(DNN_SIZES) << getName()
                  << ": bs: " << bs_
                  << ", ic: " << ic_
                  << ", ih: " << ih_
                  << ", iw: " << iw_
                  << ", oc: " << oc_
                  << ", oh: " << oh_
                  << ", ow: " << ow_;
}
|
|
|
|
// Builds an MKL-DNN memory descriptor from the given dimensions, memory
// format and element data type.
mem::desc MkldnnLayer::createMD(mem::dims dims,
                                mem::format fmt,
                                mem::data_type type) {
  // TODO(TJ): isFmtSupported(fmt)
  mem::desc md(dims, type, fmt);
  return md;
}
|
|
|
|
} // namespace paddle
|