@@ -220,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() {
   pipelineBwd_.push_back(*bwdWgt_);
 
   /// backward data
-  device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
-  const MatrixPtr& in = getInputGrad(0, device);
+  const MatrixPtr& in = inputLayers_[0]->getOutput().grad;
   if (in == nullptr) {
     return;
   }
-  if (getInput(0, device).getAllCount() > 1) {
-    // TODO(TJ): use outputMaps_ ways when merge outgrad done
+  if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
+    // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
   } else {
     inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
   }
@@ -243,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() {
   pipelineBwd_.push_back(*bwdData_);
 }
 
+void MKLDNNFcLayer::updateInputData() {
+  if (inputLayers_[0]->getType() != "data") {
+    return;
+  }
+  real* iData = getInputValue(0, CPU_DEVICE)->getData();
+  inVal_->setData(iData);
+}
+
 void MKLDNNFcLayer::forward(PassType passType) {
   Layer::forward(passType);
   reshape();
 
   {
     REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
-    syncInputValue();
+    updateInputData();
 
     // just submit forward pipeline
     stream_->submit(pipelineFwd_);
@@ -271,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
     REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
     resetBwd();
 
-    syncOutputGrad();
     // just submit backward pipeline
     stream_->submit(pipelineBwd_);
   }