@@ -52,41 +52,15 @@ public:
   /**
    * reset the forward primitives
    */
-  virtual void resetFwd(Argument& act) {
-    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
-    cnt_ = act.value->getElementCnt();
-    pipelineFwd_.clear();
-    stream_.reset(new MKLDNNStream());
-    engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
-    val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
-    if (val_ == nullptr) {
-      int bs = act.getBatchSize();
-      int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
-      int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
-      int ic = cnt_ / bs / ih / iw;
-      CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
-      val_ = MKLDNNMatrix::create(
-          act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_);
-      CHECK(val_);
-      val_->downSpatial();
-    }
-  }
+  virtual void resetFwd(Argument& act);
   /**
    * reset the backward primitives;
    * cannot merge this function into resetFwd, as the grad data
    * would change before backward runs.
    */
   virtual void resetBwd(Argument& act) {}
-  virtual Error __must_check forward(Argument& act) {
-    resetFwd(act);
-    stream_->submit(pipelineFwd_);
-    return Error();
-  }
-  virtual Error __must_check backward(Argument& act) {
-    resetBwd(act);
-    stream_->submit(pipelineBwd_);
-    return Error();
-  }
+  virtual Error __must_check forward(Argument& act);
+  virtual Error __must_check backward(Argument& act);
 };

 /**
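(The bodies deleted above presumably move out-of-line to match the new
declarations; the destination file is not shown in this diff. A minimal
sketch of what the relocated definition of forward() would look like,
assuming a companion MKLDNNActivation.cpp and the same member names:

  Error MKLDNNActivation::forward(Argument& act) {
    resetFwd(act);                  // (re)build primitives for this input
    stream_->submit(pipelineFwd_);  // execute the queued MKL-DNN primitives
    return Error();
  }

backward() would follow the same pattern with resetBwd and pipelineBwd_.)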
@@ -96,6 +70,7 @@ public:
 class MKLDNNEltwiseActivation : public MKLDNNActivation {
   typedef mkldnn::eltwise_forward eltwise_fwd;
   typedef mkldnn::eltwise_backward eltwise_bwd;
+  typedef mkldnn::algorithm algorithm;

 protected:
   // save the forward primitive desc, which can be used in backward
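(Reviewer note: the added typedef lets the new getAlgo declaration in the
next hunk return the short name "algorithm" instead of spelling out
mkldnn::algorithm, as the removed inline definition did.)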
@@ -115,68 +90,9 @@ public:
   virtual float getAlpha() const = 0;
   virtual float getBwdAlpha() const = 0;
   virtual float getBeta() const { return 0.f; }
-  virtual mkldnn::algorithm getAlgo(const std::string& type) const {
-    if (type == "mkldnn_relu") {
-      return mkldnn::algorithm::eltwise_relu;
-    } else if (type == "mkldnn_tanh") {
-      return mkldnn::algorithm::eltwise_tanh;
-    } else if (type == "mkldnn_elu") {
-      return mkldnn::algorithm::eltwise_elu;
-    } else {
-      LOG(FATAL) << "Unknown eltwise activation type: " << type;
-    }
-    return (mkldnn::algorithm)0;
-  }
-
-  void resetFwd(Argument& act) override {
-    if (cnt_ == act.value->getElementCnt()) {
-      return;
-    }
-    MKLDNNActivation::resetFwd(act);
-    // note: alpha represents the NegativeSlope when used in relu.
-    float alpha = getAlpha();
-    float beta = getBeta();
-    mkldnn::algorithm algo = getAlgo(this->getName());
-    auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
-                                     algo,
-                                     val_->getMemoryDesc(),
-                                     alpha,
-                                     beta);
-    fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, *engine_));
-    // use in-place for forward, but save the input value before submit
-    inVal_ = val_;
-    copyInVal_ = nullptr;
-    if (act.grad && algo == mkldnn::algorithm::eltwise_tanh) {
-      // tanh needs to save the src input for backward
-      inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
-      copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
-      CHECK(copyInVal_) << "should not be empty";
-      pipelineFwd_.push_back(*copyInVal_);
-    }
-    fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
-    pipelineFwd_.push_back(*fwd_);
-    needResetBwd_ = true;
-  }
-
-  void resetBwd(Argument& act) override {
-    if (!needResetBwd_) {
-      return;
-    }
-    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
-    needResetBwd_ = false;
-    mkldnn::algorithm algo = getAlgo(this->getName());
-    float alpha = getBwdAlpha();
-    float beta = getBeta();
-    grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
-    auto eng = CPUEngine::Instance().getEngine();
-    auto bwdDesc = eltwise_bwd::desc(
-        algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
-    auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
-    CHECK(inVal_);
-    bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
-    pipelineBwd_.clear();
-    pipelineBwd_.push_back(*bwd_);
-  }
+  virtual algorithm getAlgo(std::string type) const;
+  void resetFwd(Argument& act) override;
+  void resetBwd(Argument& act) override;
 };

 /**
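(Reviewer notes on this hunk: the forward primitive runs in place, since
eltwise_fwd(*fwdPD_, *val_, *val_) overwrites the source with the result.
A sketch of why that is safe for relu but not for tanh, assuming the usual
MKL-DNN eltwise backward formulas:

  // relu backward: diff_src = diff_dst * (src > 0 ? 1 : alpha)
  //   dst > 0 exactly where src > 0, so dst can stand in for src
  // tanh backward: diff_src = diff_dst * (1 - tanh(src)^2)
  //   feeding dst = tanh(src) would compute 1 - tanh(dst)^2, which is
  //   wrong; hence the extra reorder that preserves the input in inVal_

Note also that the new getAlgo declaration takes std::string by value,
whereas the removed definition took const std::string&; the out-of-line
definition will need the matching signature.)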
@@ -195,45 +111,9 @@ public:
   MKLDNNSoftmaxActivation() {}
   ~MKLDNNSoftmaxActivation() {}
   virtual const std::string& getName() const = 0;
-  void resetFwd(Argument& act) override {
-    if (cnt_ == act.value->getElementCnt()) {
-      return;
-    }
-    MKLDNNActivation::resetFwd(act);
-    int axis = 1;
-    auto fwdDesc = softmax_fwd::desc(
-        mkldnn::prop_kind::forward_scoring, val_->getMemoryDesc(), axis);
-    auto fwdPD = softmax_fwd::primitive_desc(fwdDesc, *engine_);
-    fwd_.reset(new softmax_fwd(fwdPD, *val_, *val_));
-    pipelineFwd_.push_back(*fwd_);
-  }
-
-  Error __must_check backward(Argument& act) override {
-    MatrixPtr outputV = act.value;
-    MatrixPtr outputG = act.grad;
-
-    if (outputG->useGpu()) {
-      outputG->softmaxBackward(*outputV);
-    } else {
-      SetDevice device(act.deviceId);
-      Matrix::resizeOrCreate(sftMaxDot_,
-                             outputG->getHeight(),
-                             outputG->getWidth(),
-                             /* trans */ false,
-                             useGpu(act.deviceId));
-      Matrix::resizeOrCreate(sftMaxSum_,
-                             outputG->getHeight(),
-                             1,
-                             /* trans */ false,
-                             useGpu(act.deviceId));
-
-      sftMaxDot_->dotMul(*outputG, *outputV);
-      sftMaxSum_->colMerge(*sftMaxDot_);
-
-      act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
-    }
-    return Error();
-  }
+  void resetFwd(Argument& act) override;
+  Error __must_check forward(Argument& act) override;
+  Error __must_check backward(Argument& act) override;
 };

 } // namespace paddle
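(Reviewer note on the removed softmax backward: the CPU branch computes the
standard softmax gradient. With y = act.value (the softmax output) and
dy = act.grad, dotMul forms the elementwise product y .* dy, colMerge sums
each row of it into sftMaxSum_, and softmaxDerivative then applies

  dx_i = y_i * (dy_i - sum_j y_j * dy_j)

row by row. Since this hunk only leaves declarations behind, the same
computation presumably moves unchanged into the out-of-line definitions.)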