|
|
|
@ -14,7 +14,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "MKLDNNFcLayer.h"
|
|
|
|
|
#include "paddle/utils/Logging.h"
|
|
|
|
|
#include "paddle/utils/Stat.h"
|
|
|
|
|
|
|
|
|
|
using namespace mkldnn; // NOLINT
|
|
|
|
|
typedef memory::format format;
|
|
|
|
@ -40,6 +39,8 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
|
|
|
|
|
oc_ = getSize();
|
|
|
|
|
oh_ = 1;
|
|
|
|
|
ow_ = 1;
|
|
|
|
|
ih_ = 1;
|
|
|
|
|
iw_ = 1;
|
|
|
|
|
|
|
|
|
|
// input size can not change in FC
|
|
|
|
|
iLayerSize_ = inputLayers_[0]->getSize();
|
|
|
|
@ -78,36 +79,17 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNFcLayer::reshape() {
|
|
|
|
|
const Argument& input = getInput(0, getPrev(0)->getDeviceId());
|
|
|
|
|
int batchSize = input.getBatchSize();
|
|
|
|
|
if (bs_ == batchSize) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
bs_ = batchSize;
|
|
|
|
|
ih_ = input.getFrameHeight();
|
|
|
|
|
iw_ = input.getFrameWidth();
|
|
|
|
|
if (ih_ == 0) {
|
|
|
|
|
ih_ = 1;
|
|
|
|
|
}
|
|
|
|
|
if (iw_ == 0) {
|
|
|
|
|
iw_ = 1;
|
|
|
|
|
}
|
|
|
|
|
reshapeInput();
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize());
|
|
|
|
|
ic_ = iLayerSize_ / (ih_ * iw_);
|
|
|
|
|
CHECK_EQ(size_t(ic_ * ih_ * iw_), iLayerSize_) << "not divisible";
|
|
|
|
|
CHECK_EQ(size_t(oc_), getSize());
|
|
|
|
|
printSizeInfo();
|
|
|
|
|
|
|
|
|
|
// reset output
|
|
|
|
|
output_.setFrameHeight(oh_);
|
|
|
|
|
output_.setFrameWidth(ow_);
|
|
|
|
|
resetOutput(bs_, oc_);
|
|
|
|
|
reshapeOutput(oh_, ow_);
|
|
|
|
|
resizeOutput(bs_, oc_);
|
|
|
|
|
|
|
|
|
|
// reset mkldnn forward
|
|
|
|
|
resetFwd();
|
|
|
|
|
needResetBwd_ = true;
|
|
|
|
|
|
|
|
|
|
convertWeightsFromPaddle();
|
|
|
|
|
printSizeInfo();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNFcLayer::resetFwd() {
|
|
|
|
@ -137,7 +119,6 @@ void MKLDNNFcLayer::resetFwd() {
|
|
|
|
|
// change original output value to mkldnn output value
|
|
|
|
|
output_.value = std::dynamic_pointer_cast<Matrix>(outVal_);
|
|
|
|
|
if (!outputIsOnlyMKLDNN()) {
|
|
|
|
|
copyOutputInfoToOtherDevice();
|
|
|
|
|
// fc cpu output value do not need create convert
|
|
|
|
|
// just share point
|
|
|
|
|
getOutput(CPU_DEVICE).value->setData(output_.value->getData());
|
|
|
|
@ -243,51 +224,13 @@ void MKLDNNFcLayer::resetBwd() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNFcLayer::updateInputData() {
|
|
|
|
|
if (inputLayers_[0]->getType() != "data") {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
real* iData = getInputValue(0, CPU_DEVICE)->getData();
|
|
|
|
|
inVal_->setData(iData);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNFcLayer::forward(PassType passType) {
|
|
|
|
|
Layer::forward(passType);
|
|
|
|
|
reshape();
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
|
|
|
|
|
updateInputData();
|
|
|
|
|
|
|
|
|
|
// just submit forward pipeline
|
|
|
|
|
stream_->submit(pipelineFwd_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/* activation */ {
|
|
|
|
|
REGISTER_TIMER_INFO("FwActTimer", getName().c_str());
|
|
|
|
|
forwardActivation();
|
|
|
|
|
}
|
|
|
|
|
inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
/* Do derivation */ {
|
|
|
|
|
REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
|
|
|
|
|
backwardActivation();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
|
|
|
|
|
resetBwd();
|
|
|
|
|
|
|
|
|
|
// just sumbmit backward pipeline
|
|
|
|
|
stream_->submit(pipelineBwd_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
|
|
|
|
|
weight_->getParameterPtr()->incUpdate(callback);
|
|
|
|
|
if (biases_ && biases_->getWGrad()) {
|
|
|
|
|
biases_->getParameterPtr()->incUpdate(callback);
|
|
|
|
|
}
|
|
|
|
|
void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) {
|
|
|
|
|
weight_->getParameterPtr()->incUpdate(callback);
|
|
|
|
|
if (biases_ && biases_->getWGrad()) {
|
|
|
|
|
biases_->getParameterPtr()->incUpdate(callback);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace paddle
|
|
|
|
|