|
|
|
@ -323,7 +323,7 @@ protected:
|
|
|
|
|
if (mat == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dnn = MKLDNNMatrix::create(mat, pd);
|
|
|
|
|
dnn = MKLDNNMatrix::create(pd, mat);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -343,7 +343,7 @@ protected:
|
|
|
|
|
in = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
|
|
|
|
|
CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr);
|
|
|
|
|
if (in == nullptr || in->getFormat() == mkldnn::memory::format::nc) {
|
|
|
|
|
in = MKLDNNMatrix::create(inMat, extPD);
|
|
|
|
|
in = MKLDNNMatrix::create(extPD, inMat);
|
|
|
|
|
}
|
|
|
|
|
extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr;
|
|
|
|
|
if (in->getFormat() == mkldnn::memory::format::nc) {
|
|
|
|
@ -353,8 +353,8 @@ protected:
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// need create reorder
|
|
|
|
|
in = MKLDNNMatrix::create(nullptr, *intPD);
|
|
|
|
|
extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(inMat, extPD);
|
|
|
|
|
in = MKLDNNMatrix::create(*intPD);
|
|
|
|
|
extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(extPD, inMat);
|
|
|
|
|
cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in);
|
|
|
|
|
CHECK(cvtInVal_) << "should not be emptry";
|
|
|
|
|
}
|
|
|
|
@ -366,18 +366,18 @@ protected:
|
|
|
|
|
void resetOutValue(MKLDNNMatrixPtr& out,
|
|
|
|
|
mkldnn::memory::primitive_desc intPD) {
|
|
|
|
|
cvtOutVal_ = nullptr;
|
|
|
|
|
out = MKLDNNMatrix::create(output_.value, intPD);
|
|
|
|
|
out = MKLDNNMatrix::create(intPD, output_.value);
|
|
|
|
|
extOutVal_ = out;
|
|
|
|
|
if (outputIsOnlyMKLDNN() || isPaddleFormat(extOutVal_->getFormat())) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// need create reorder
|
|
|
|
|
CHECK_GT(bs_ * oc_ * oh_ * ow_, 0);
|
|
|
|
|
extOutVal_ = MKLDNNMatrix::create(output_.value,
|
|
|
|
|
{bs_, oc_, oh_, ow_},
|
|
|
|
|
extOutVal_ = MKLDNNMatrix::create(mkldnn::memory::dims{bs_, oc_, oh_, ow_},
|
|
|
|
|
mkldnn::memory::format::nchw,
|
|
|
|
|
engine_);
|
|
|
|
|
out = MKLDNNMatrix::create(nullptr, intPD);
|
|
|
|
|
engine_,
|
|
|
|
|
output_.value);
|
|
|
|
|
out = MKLDNNMatrix::create(intPD);
|
|
|
|
|
cvtOutVal_ = MKLDNNMatrix::createReorder(out, extOutVal_);
|
|
|
|
|
CHECK(cvtOutVal_) << "should not be empty";
|
|
|
|
|
}
|
|
|
|
@ -402,7 +402,7 @@ protected:
|
|
|
|
|
// and the mkldnn input layer will merge them to actual prev->output_.grad
|
|
|
|
|
const MatrixPtr& inMat =
|
|
|
|
|
input->getOutputMapSize() <= 1 ? input->getOutputGrad() : nullptr;
|
|
|
|
|
in = MKLDNNMatrix::create(inMat, intPD);
|
|
|
|
|
in = MKLDNNMatrix::create(intPD, inMat);
|
|
|
|
|
Argument& arg = input->getOutput(this->getName());
|
|
|
|
|
arg.grad = std::dynamic_pointer_cast<Matrix>(in);
|
|
|
|
|
CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
|
|
|
|
@ -418,10 +418,10 @@ protected:
|
|
|
|
|
// need create reorder
|
|
|
|
|
CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat()))
|
|
|
|
|
<< "should have external input value and the format must be nchw(nc)";
|
|
|
|
|
extInGrad_ = MKLDNNMatrix::create(inMat, extInVal_->getPrimitiveDesc());
|
|
|
|
|
extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat);
|
|
|
|
|
CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
|
|
|
|
|
<< "should have internal input value and primitive desc must equal";
|
|
|
|
|
in = MKLDNNMatrix::create(nullptr, intPD);
|
|
|
|
|
in = MKLDNNMatrix::create(intPD);
|
|
|
|
|
cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_);
|
|
|
|
|
CHECK(cvtInGrad_);
|
|
|
|
|
}
|
|
|
|
@ -440,7 +440,7 @@ protected:
|
|
|
|
|
extOutGrad_ = nullptr;
|
|
|
|
|
out = nullptr;
|
|
|
|
|
MatrixPtr& outMat = output_.grad;
|
|
|
|
|
out = MKLDNNMatrix::create(outMat, intPD);
|
|
|
|
|
out = MKLDNNMatrix::create(intPD, outMat);
|
|
|
|
|
resetMergeGrad(out);
|
|
|
|
|
if (outputIsOnlyMKLDNN()) {
|
|
|
|
|
return;
|
|
|
|
@ -453,10 +453,10 @@ protected:
|
|
|
|
|
// need create reorder
|
|
|
|
|
CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat()))
|
|
|
|
|
<< "should have external output value and the format must be nchw(nc)";
|
|
|
|
|
extOutGrad_ = MKLDNNMatrix::create(outMat, extOutVal_->getPrimitiveDesc());
|
|
|
|
|
extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat);
|
|
|
|
|
CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD)
|
|
|
|
|
<< "should have internal output value and primitive desc must equal";
|
|
|
|
|
out = MKLDNNMatrix::create(nullptr, intPD);
|
|
|
|
|
out = MKLDNNMatrix::create(intPD);
|
|
|
|
|
cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out);
|
|
|
|
|
CHECK(cvtOutGrad_);
|
|
|
|
|
}
|
|
|
|
@ -499,7 +499,7 @@ protected:
|
|
|
|
|
tmpOutGrad_ = out;
|
|
|
|
|
tmpCvt_ = nullptr;
|
|
|
|
|
if (out->getPrimitiveDesc() != srcPDs[0]) {
|
|
|
|
|
tmpOutGrad_ = MKLDNNMatrix::create(nullptr, srcPDs[0]);
|
|
|
|
|
tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]);
|
|
|
|
|
tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out);
|
|
|
|
|
CHECK(tmpCvt_);
|
|
|
|
|
pipelineMergeGrad_.push_back(*tmpCvt_);
|
|
|
|
|