|
|
|
@ -294,22 +294,8 @@ void MKLDNNLayer::resetMergeGrad(MKLDNNMatrixPtr& out) {
|
|
|
|
|
srcs.push_back(*src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): remove me when mkldnn sum support different formats
|
|
|
|
|
for (size_t i = 1; i < srcPDs.size(); ++i) {
|
|
|
|
|
CHECK(srcPDs[0] == srcPDs[i]);
|
|
|
|
|
}
|
|
|
|
|
tmpOutGrad_ = out;
|
|
|
|
|
tmpCvt_ = nullptr;
|
|
|
|
|
if (out->getPrimitiveDesc() != srcPDs[0]) {
|
|
|
|
|
tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]);
|
|
|
|
|
tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out);
|
|
|
|
|
CHECK(tmpCvt_);
|
|
|
|
|
pipelineMergeGrad_.push_back(*tmpCvt_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sumPD =
|
|
|
|
|
sum::primitive_desc(tmpOutGrad_->getMemoryDesc(), scales, srcPDs);
|
|
|
|
|
mergeGrad_.reset(new sum(sumPD, srcs, *tmpOutGrad_));
|
|
|
|
|
auto sumPD = sum::primitive_desc(out->getMemoryDesc(), scales, srcPDs);
|
|
|
|
|
mergeGrad_.reset(new sum(sumPD, srcs, *out));
|
|
|
|
|
pipelineMergeGrad_.insert(pipelineMergeGrad_.begin(), *mergeGrad_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|