parent f6a940936b
commit b2bd67133a
@@ -1,222 +0,0 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MkldnnLayer.h"

using mem = mkldnn::memory;  // NOLINT
typedef mem::format format;
typedef mkldnn::inner_product_forward fc_fwd;
typedef mkldnn::inner_product_backward_weights fc_bwdWgt;
typedef mkldnn::inner_product_backward_data fc_bwdData;

namespace paddle {

bool MkldnnLayer::init(const LayerMap& layerMap,
                       const ParameterMap& parameterMap) {
  if (!Layer::init(layerMap, parameterMap)) {
    return false;
  }

  CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn. "
                          << "Please set WITH_MKLDNN=ON "
                          << "and set use_mkldnn=True";
  stream_.reset(new MkldnnStream());
  engine_ = CpuEngine::Instance().getEngine();

  // TODO(TJ): deviceId
  return true;
}

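// Build (or rebuild) the MKL-DNN forward FC primitive: create memory
// descriptors for bottom/weight/bias/top, wrap the raw buffers as MKL-DNN
// memories, and chain one inner_product_forward into pipelineFwd_.
// Returns early if the bottom primitive descriptor is unchanged.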
void MkldnnLayer::resetForwardFC(int bs,
                                 int ic,
                                 int ih,
                                 int iw,
                                 real* botData,
                                 int oc,
                                 real* topData,
                                 real* wgtData,
                                 real* biasData) {
  bool hasSpatial = !(ih == 1 && iw == 1);
  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
                               : createMD({bs, ic}, format::nc);
  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
                               : createMD({oc, ic}, format::oi);
  mem::desc biasMD = biasData != NULL ? createMD({oc}, format::x)
                                      : createMD({}, format::format_undef);
  mem::desc topMD = createMD({bs, oc}, format::nc);

  mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
  if (inVal_ && inVal_->get_primitive_desc() == botPD) {
    return;
  }

  inVal_.reset(new mem(botPD, botData));
  wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
  outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));

  mkldnn::prop_kind pk = mkldnn::prop_kind::forward;
  fc_fwd::desc fwdDesc = biasData != NULL
                             ? fc_fwd::desc(pk, botMD, wgtMD, biasMD, topMD)
                             : fc_fwd::desc(pk, botMD, wgtMD, topMD);
  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);

  if (biasData != NULL) {
    biasVal_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasData));
    fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
  } else {
    fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
  }
  pipelineFwd_.clear();
  pipelineFwd_.push_back(*fwd_);
}

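// Forward entry point: refresh the primitives if the input shape changed,
// convert the weights from the PaddlePaddle layout, rebind the input
// buffer, and submit the forward pipeline to the MKL-DNN stream.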
void MkldnnLayer::mkldnnForwardFC(int bs,
                                  int ic,
                                  int ih,
                                  int iw,
                                  real* botData,
                                  int oc,
                                  real* topData,
                                  real* wgtData,
                                  real* biasData) {
  // if input size changed, reset it
  resetForwardFC(bs, ic, ih, iw, botData, oc, topData, wgtData, biasData);

  this->convertWeightsFromPaddle();

  // update input, since the data might be changed if this is after data layer
  inVal_->set_data_handle(botData);

  // just forward
  stream_->submit(pipelineFwd_);
}

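// Build (or rebuild) the MKL-DNN backward primitives: backward-weights
// (and bias) always, plus backward-data when a bottom gradient buffer is
// given. Returns early if the top gradient primitive descriptor is unchanged.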
void MkldnnLayer::resetBackwardFC(int bs,
                                  int ic,
                                  int ih,
                                  int iw,
                                  real* botDiff,
                                  real* botData,
                                  int oc,
                                  real* topDiff,
                                  real* wgtDiff,
                                  real* wgtData,
                                  real* biasDiff) {
  bool hasSpatial = !(ih == 1 && iw == 1);

  // backward weight
  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
                               : createMD({bs, ic}, format::nc);
  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
                               : createMD({oc, ic}, format::oi);
  mem::desc topMD = createMD({bs, oc}, format::nc);
  mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
                                      : createMD({}, format::format_undef);

  // top (output gradient) primitive descriptor
  mem::primitive_desc topPD = mem::primitive_desc(topMD, engine_);
  if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
    return;
  }

  if (inVal_) {
    // update data
    inVal_->set_data_handle(botData);
  } else {
    inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
  }
  wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
  outGrad_.reset(new mem(topPD, topDiff));

  fc_fwd::desc fwdDesc =
      fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
  fc_bwdWgt::desc bwdWgtDesc =
      biasDiff != NULL ? fc_bwdWgt::desc(botMD, wgtMD, biasMD, topMD)
                       : fc_bwdWgt::desc(botMD, wgtMD, topMD);
  fc_bwdWgt::primitive_desc bwdWgtPD =
      fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);

  if (biasDiff != NULL) {
    biasGrad_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasDiff));
    bwdWgt_.reset(
        new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
  } else {
    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_));
  }
  pipelineBwd_.clear();
  pipelineBwd_.push_back(*bwdWgt_);

  // backward data
  if (botDiff == NULL) {
    return;
  }

  fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
  fc_bwdData::primitive_desc bwdDataPD =
      fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
  inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
  if (wgtVal_) {
    // update data
    wgtVal_->set_data_handle(wgtData);
  } else {
    wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
  }
  bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
  pipelineBwd_.push_back(*bwdData_);
}

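// Backward entry point: refresh the primitives if the shapes changed,
// rebind the top gradient buffer, and submit the backward pipeline.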
void MkldnnLayer::mkldnnBackwardFC(int bs,
                                   int ic,
                                   int ih,
                                   int iw,
                                   real* botDiff,
                                   real* botData,
                                   int oc,
                                   real* topDiff,
                                   real* wgtDiff,
                                   real* wgtData,
                                   real* biasDiff) {
  // if input size changed, reset it
  resetBackwardFC(bs,
                  ic,
                  ih,
                  iw,
                  botDiff,
                  botData,
                  oc,
                  topDiff,
                  wgtDiff,
                  wgtData,
                  biasDiff);

  // update data
  outGrad_->set_data_handle(topDiff);

  stream_->submit(pipelineBwd_);
}

void MkldnnLayer::printSizeInfo() {
  VLOG(DNN_SIZES) << getName() << ": bs: " << bs_ << ", ic: " << ic_
                  << ", ih: " << ih_ << ", iw: " << iw_ << ", oc: " << oc_
                  << ", oh: " << oh_ << ", ow: " << ow_;
}

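// Convenience wrapper that builds an MKL-DNN memory descriptor from
// dimensions, format, and data type.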
mem::desc MkldnnLayer::createMD(mem::dims dims,
                                mem::format fmt,
                                mem::data_type type) {
  // TODO(TJ): isFmtSupported(fmt)
  return mem::desc(dims, type, fmt);
}

}  // namespace paddle