|
|
|
|
@ -58,25 +58,21 @@ void MKLDNNAddtoLayer::reshape(
|
|
|
|
|
|
|
|
|
|
void MKLDNNAddtoLayer::resetFwd(std::vector<primitive>& pipeline,
|
|
|
|
|
MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out) {
|
|
|
|
|
resetFwdBuffers(inVals_, bias, out);
|
|
|
|
|
resetFwdBuffers(inVals_, biasVal_, out);
|
|
|
|
|
in = inVals_[0];
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<sum::primitive_desc> fwdPD;
|
|
|
|
|
std::shared_ptr<sum::primitive_desc> biasPD;
|
|
|
|
|
resetFwdPD(fwdPD, biasPD, inVals_, bias, out);
|
|
|
|
|
resetFwdPD(fwdPD, biasPD, inVals_, biasVal_, out);
|
|
|
|
|
|
|
|
|
|
resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, bias, out);
|
|
|
|
|
resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, biasVal_, out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline,
|
|
|
|
|
MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out) {
|
|
|
|
|
resetBwdBuffers(inGrads_, bias, out);
|
|
|
|
|
resetBwdBuffers(inGrads_, biasGrad_, out);
|
|
|
|
|
in = inGrads_[0];
|
|
|
|
|
|
|
|
|
|
// backward only need share output grad to input grad
|
|
|
|
|
@ -89,15 +85,17 @@ void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline,
|
|
|
|
|
|
|
|
|
|
// backward bias
|
|
|
|
|
bwdBias_ = nullptr;
|
|
|
|
|
if (bias) {
|
|
|
|
|
if (biasGrad_) {
|
|
|
|
|
std::vector<float> scales(bs_, 1.0);
|
|
|
|
|
std::vector<memory::primitive_desc> srcPDs(bs_, bias->getPrimitiveDesc());
|
|
|
|
|
auto biasPD = sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs);
|
|
|
|
|
std::vector<memory::primitive_desc> srcPDs(bs_,
|
|
|
|
|
biasGrad_->getPrimitiveDesc());
|
|
|
|
|
auto biasPD =
|
|
|
|
|
sum::primitive_desc(biasGrad_->getMemoryDesc(), scales, srcPDs);
|
|
|
|
|
std::vector<primitive::at> srcs;
|
|
|
|
|
for (size_t i = 0; i < grads_.size(); ++i) {
|
|
|
|
|
srcs.push_back(*(grads_[i]));
|
|
|
|
|
}
|
|
|
|
|
bwdBias_.reset(new sum(biasPD, srcs, *bias));
|
|
|
|
|
bwdBias_.reset(new sum(biasPD, srcs, *biasGrad_));
|
|
|
|
|
pipeline.push_back(*bwdBias_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|