@@ -36,6 +36,7 @@ protected:
  // mkldnn matrix, primitive, stream and pipeline
  MKLDNNMatrixPtr val_;
  MKLDNNMatrixPtr grad_;
  std::shared_ptr<mkldnn::engine> engine_;
  std::shared_ptr<MKLDNNStream> stream_;
  std::shared_ptr<mkldnn::primitive> fwd_;
  std::shared_ptr<mkldnn::primitive> bwd_;
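These members cache the pieces of the MKL-DNN 0.x execution model (engine -> memory -> primitive -> pipeline -> stream) that the activations build once and then reuse. For orientation, a minimal self-contained sketch of that model outside of Paddle; the buffer name, shape, and the relu choice are illustrative only, not taken from this patch:

// Minimal sketch of the MKL-DNN 0.x pattern these members wrap.
#include <vector>
#include "mkldnn.hpp"

void runReluOnce(float* data /* illustrative 2x3x4x5 nchw buffer */) {
  mkldnn::engine eng(mkldnn::engine::cpu, 0);

  // Describe and wrap the user buffer as an MKL-DNN memory primitive.
  mkldnn::memory::desc md({2, 3, 4, 5},
                          mkldnn::memory::data_type::f32,
                          mkldnn::memory::format::nchw);
  mkldnn::memory mem(mkldnn::memory::primitive_desc(md, eng), data);

  // Build an in-place eltwise relu primitive (alpha is the negative slope).
  auto desc = mkldnn::eltwise_forward::desc(
      mkldnn::prop_kind::forward_training,
      mkldnn::algorithm::eltwise_relu, md, /*alpha=*/0.f, /*beta=*/0.f);
  auto pd = mkldnn::eltwise_forward::primitive_desc(desc, eng);

  // Collect primitives into a pipeline and submit them through a stream.
  std::vector<mkldnn::primitive> pipeline;
  pipeline.push_back(mkldnn::eltwise_forward(pd, mem, mem));
  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}

Keeping these objects as members lets the reset methods rebuild them only when the input size changes (the cnt_ check in the subclasses below) instead of on every submit.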
@@ -48,8 +49,44 @@ public:
  static ActivationFunction* create(const std::string& type);
  static std::vector<std::string> getAllRegisteredTypes();
  virtual const std::string& getName() const = 0;
  virtual Error __must_check forward(Argument& act) = 0;
  virtual Error __must_check backward(Argument& act) = 0;
  /**
   * reset the forward primitives
   */
  virtual void resetFwd(Argument& act) {
    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
    cnt_ = act.value->getElementCnt();
    pipelineFwd_.clear();
    stream_.reset(new MKLDNNStream());
    engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
    val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
    if (val_ == nullptr) {
      int bs = act.getBatchSize();
      int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
      int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
      int ic = cnt_ / bs / ih / iw;
      CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
      val_ = MKLDNNMatrix::create(
          act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_);
      CHECK(val_);
      val_->downSpatial();
    }
  }
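When act.value is not already an MKLDNNMatrix, resetFwd reconstructs an nchw shape from the flat element count. A worked example of that arithmetic, with illustrative sizes:

// Illustrative numbers for the shape recovery above: a batch of 128 rows of
// 4096 values each, with no frame height/width reported by the Argument.
size_t cnt = 128 * 4096;      // act.value->getElementCnt()
int bs = 128;                 // act.getBatchSize()
int ih = 1, iw = 1;           // getFrameHeight()/getFrameWidth() fall back to 1
int ic = cnt / bs / ih / iw;  // 4096, so the dims become {128, 4096, 1, 1}
// CHECK_EQ catches truncation in the integer division; downSpatial() then
// drops the trailing 1x1 spatial dims so the memory is treated as 2-D.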
  /**
   * reset the backward primitives;
   * cannot be merged into resetFwd, as the grad data
   * would change before backward.
   */
  virtual void resetBwd(Argument& act) {}
  virtual Error __must_check forward(Argument& act) {
    resetFwd(act);
    stream_->submit(pipelineFwd_);
    return Error();
  }
  virtual Error __must_check backward(Argument& act) {
    resetBwd(act);
    stream_->submit(pipelineBwd_);
    return Error();
  }
};
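With forward/backward now implemented in the base class, the contract for callers is simply reset-then-submit. A hedged usage sketch; the "mkldnn_relu" type, the Argument `arg`, and the ownership handling are illustrative, not part of this patch:

// Sketch: driving an MKLDNN activation through the base-class entry points.
std::unique_ptr<ActivationFunction> act(MKLDNNActivation::create("mkldnn_relu"));
Error err = act->forward(arg);   // resetFwd(arg), then submit pipelineFwd_
err = act->backward(arg);        // resetBwd(arg), then submit pipelineBwd_
// Both return a default-constructed Error on success, as shown above.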

/**

@@ -70,9 +107,7 @@ protected:
public:
  MKLDNNEltwiseActivation() {}
  ~MKLDNNEltwiseActivation() {}
  virtual const std::string& getName() const = 0;
  // in general, the alpha of forward and backward should be equal.
@@ -93,42 +128,21 @@ public:
    return (mkldnn::algorithm)0;
  }
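Only the tail of getAlgo is visible in this hunk; the cast to (mkldnn::algorithm)0 is its fallthrough for unknown names. A hypothetical mapping of the shape it presumably implements, illustrative only and not the actual function:

// Sketch of a type-name -> algorithm mapping (not the patch's getAlgo).
#include <string>
#include "mkldnn.hpp"

mkldnn::algorithm algoFor(const std::string& type) {
  if (type == "mkldnn_relu") return mkldnn::algorithm::eltwise_relu;
  if (type == "mkldnn_tanh") return mkldnn::algorithm::eltwise_tanh;
  if (type == "mkldnn_elu") return mkldnn::algorithm::eltwise_elu;
  return (mkldnn::algorithm)0;  // same fallthrough as above
}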
  /**
   * reshape and reset the forward primitives
   */
  void resetFwd(Argument& act) {
  void resetFwd(Argument& act) override {
    if (cnt_ == act.value->getElementCnt()) {
      return;
    }
    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
    cnt_ = act.value->getElementCnt();
    stream_.reset(new MKLDNNStream());
    auto eng = CPUEngine::Instance().getEngine();
    // get algo setting
    mkldnn::algorithm algo = getAlgo(this->getName());
    MKLDNNActivation::resetFwd(act);
    // note: alpha represents the NegativeSlope when used in relu.
    float alpha = getAlpha();
    float beta = getBeta();
    pipelineFwd_.clear();
    val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
    if (val_ == nullptr) {
      int bs = act.getBatchSize();
      int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
      int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
      int ic = cnt_ / bs / ih / iw;
      CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
      val_ = MKLDNNMatrix::create(
          act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, eng);
      CHECK(val_);
    }
    mkldnn::algorithm algo = getAlgo(this->getName());
    auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
                                     algo,
                                     val_->getMemoryDesc(),
                                     alpha,
                                     beta);
    fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, eng));
    fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, *engine_));
    // use in-place forward, but save the input value before submitting
    inVal_ = val_;
    copyInVal_ = nullptr;
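Since alpha and beta parametrize whichever algorithm getAlgo selects (the note above calls out alpha as relu's negative slope), a scalar reference of what the common eltwise algorithms compute may help; it mirrors the usual MKL-DNN 0.x definitions and is not part of the patch:

// Scalar reference for the eltwise algorithms selectable here (illustrative).
#include <cmath>
#include "mkldnn.hpp"

inline float eltwiseRef(mkldnn::algorithm algo, float x, float alpha) {
  switch (algo) {
    case mkldnn::algorithm::eltwise_relu:  // alpha is the negative slope
      return x > 0.f ? x : alpha * x;
    case mkldnn::algorithm::eltwise_tanh:  // alpha/beta unused
      return std::tanh(x);
    case mkldnn::algorithm::eltwise_elu:   // alpha scales the negative branch
      return x > 0.f ? x : alpha * (std::exp(x) - 1.f);
    default:
      return x;
  }
}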
@@ -144,11 +158,7 @@ public:
    needResetBwd_ = true;
  }

  /**
   * reset the backward primitives; cannot be merged into resetFwd, as the
   * grad data would change before backward.
   */
  void resetBwd(Argument& act) {
  void resetBwd(Argument& act) override {
    if (!needResetBwd_) {
      return;
    }
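The body of resetBwd mostly falls outside this hunk. For orientation, a sketch of how an MKL-DNN 0.x eltwise backward primitive is typically assembled from the cached forward primitive_desc (fwdPD_ is in the patch; eng, srcMem, diffDstMem, diffSrcMem, algo, alpha, and beta are illustrative stand-ins, and this is not the patch's actual body). It also explains why the forward pass saves its input via inVal_/copyInVal_ above: the in-place forward overwrites the source values that the backward primitive reads.

// Sketch (not the patch's code): building an eltwise backward primitive.
auto bwdDesc = mkldnn::eltwise_backward::desc(
    algo,
    diffDstMem.get_primitive_desc().desc(),  // dL/dy layout
    srcMem.get_primitive_desc().desc(),      // original input layout
    alpha,
    beta);
auto bwdPD = mkldnn::eltwise_backward::primitive_desc(bwdDesc, eng, *fwdPD_);
auto bwd = mkldnn::eltwise_backward(bwdPD, srcMem, diffDstMem, diffSrcMem);
pipeline.push_back(bwd);  // submitted later through the stream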
@@ -167,16 +177,61 @@ public:
    pipelineBwd_.clear();
    pipelineBwd_.push_back(*bwd_);
  }
};

  Error __must_check forward(Argument& act) {
    resetFwd(act);
    stream_->submit(pipelineFwd_);
    return Error();
/**
 * @brief Base class of MKLDNN softmax Activation;
 * only the forward uses MKL-DNN, the backward uses the CPU implementation.
 */
class MKLDNNSoftmaxActivation : public MKLDNNActivation {
  typedef mkldnn::softmax_forward softmax_fwd;

private:
  // for backward
  MatrixPtr sftMaxSum_;
  MatrixPtr sftMaxDot_;

public:
  MKLDNNSoftmaxActivation() {}
  ~MKLDNNSoftmaxActivation() {}
  virtual const std::string& getName() const = 0;
  void resetFwd(Argument& act) override {
    if (cnt_ == act.value->getElementCnt()) {
      return;
    }
    MKLDNNActivation::resetFwd(act);
    int axis = 1;
    auto fwdDesc = softmax_fwd::desc(
        mkldnn::prop_kind::forward_scoring, val_->getMemoryDesc(), axis);
    auto fwdPD = softmax_fwd::primitive_desc(fwdDesc, *engine_);
    fwd_.reset(new softmax_fwd(fwdPD, *val_, *val_));
    pipelineFwd_.push_back(*fwd_);
  }
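prop_kind::forward_scoring requests an inference-only primitive, which fits the class comment: only the forward path runs through MKL-DNN. With axis = 1, each row of the matrix is normalized as

\[ y_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)} \]

and the result is written back in place into val_.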
  Error __must_check backward(Argument& act) {
    resetBwd(act);
    stream_->submit(pipelineBwd_);
  Error __must_check backward(Argument& act) override {
    MatrixPtr outputV = act.value;
    MatrixPtr outputG = act.grad;

    if (outputG->useGpu()) {
      outputG->softmaxBackward(*outputV);
    } else {
      SetDevice device(act.deviceId);
      Matrix::resizeOrCreate(sftMaxDot_,
                             outputG->getHeight(),
                             outputG->getWidth(),
                             /* trans */ false,
                             useGpu(act.deviceId));
      Matrix::resizeOrCreate(sftMaxSum_,
                             outputG->getHeight(),
                             1,
                             /* trans */ false,
                             useGpu(act.deviceId));

      sftMaxDot_->dotMul(*outputG, *outputV);
      sftMaxSum_->colMerge(*sftMaxDot_);

      act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
    }
    return Error();
  }
};
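The CPU branch above is the standard softmax Jacobian-vector product. With y the softmax output (act.value) and dL/dy the incoming gradient (act.grad), each row is transformed as

\[ \frac{\partial L}{\partial x_i} = y_i\left(\frac{\partial L}{\partial y_i} - \sum_j \frac{\partial L}{\partial y_j}\, y_j\right) \]

sftMaxDot_ holds the elementwise product of the gradient with y, colMerge reduces it to the per-row sum, and softmaxDerivative applies the outer expression in place on act.grad.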