|
|
|
@ -135,33 +135,51 @@ void MKLDNNFcLayer::reshape() {
|
|
|
|
|
|
|
|
|
|
void MKLDNNFcLayer::resetFwd() {
|
|
|
|
|
bool hasBias = biases_ && biases_->getW();
|
|
|
|
|
real* iData = getInputValue(0)->getData();
|
|
|
|
|
real* oData = getOutputValue()->getData();
|
|
|
|
|
real* wData = weight_->getW()->getData();
|
|
|
|
|
real* bData = hasBias ? biases_->getW()->getData() : NULL;
|
|
|
|
|
const MatrixPtr& in = getInputValue(0);
|
|
|
|
|
const MatrixPtr& wgt = weight_->getW();
|
|
|
|
|
const MatrixPtr& bias = hasBias ? biases_->getW() : nullptr;
|
|
|
|
|
const MatrixPtr& out = output_.value;
|
|
|
|
|
|
|
|
|
|
if (getPrev(0)->getDeviceId() == MKLDNN_DEVICE) {
|
|
|
|
|
inVal_ = std::dynamic_pointer_cast<MKLDNNMatrix>(in);
|
|
|
|
|
CHECK(inVal_) << "Input should be MKLDNNMatrix";
|
|
|
|
|
// TODO: change input nchw to nc if available
|
|
|
|
|
// inVal_->downSpatial()
|
|
|
|
|
} else {
|
|
|
|
|
inVal_ = MKLDNNMatrix::create(
|
|
|
|
|
in,
|
|
|
|
|
hasSpatial_ ? memory::dims{bs_, ic_, ih_, iw_} : memory::dims{bs_, ic_},
|
|
|
|
|
hasSpatial_ ? format::nchw : format::nc,
|
|
|
|
|
engine_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): below create should be covered in MkldnnMatrix
|
|
|
|
|
// create memory desc
|
|
|
|
|
memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw)
|
|
|
|
|
: createMD({bs_, ic_}, format::nc);
|
|
|
|
|
memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw)
|
|
|
|
|
: createMD({oc_, ic_}, format::oi);
|
|
|
|
|
memory::desc bMD = bData != NULL ? createMD({oc_}, format::x)
|
|
|
|
|
: createMD({}, format::format_undef);
|
|
|
|
|
memory::desc oMD = createMD({bs_, oc_}, format::nc);
|
|
|
|
|
wgtVal_ = MKLDNNMatrix::create(
|
|
|
|
|
wgt,
|
|
|
|
|
hasSpatial_ ? memory::dims{oc_, ic_, ih_, iw_} : memory::dims{oc_, ic_},
|
|
|
|
|
hasSpatial_ ? format::oihw : format::oi,
|
|
|
|
|
engine_);
|
|
|
|
|
|
|
|
|
|
// create memory primitive desc and memory self
|
|
|
|
|
inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
|
|
|
|
|
wgtVal_.reset(new memory(memory::primitive_desc(wMD, engine_), wData));
|
|
|
|
|
outVal_.reset(new memory(memory::primitive_desc(oMD, engine_), oData));
|
|
|
|
|
biasVal_ =
|
|
|
|
|
hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr;
|
|
|
|
|
|
|
|
|
|
outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_);
|
|
|
|
|
|
|
|
|
|
// change original output to mkldnn output
|
|
|
|
|
output_.value = std::dynamic_pointer_cast<Matrix>(outVal_);
|
|
|
|
|
|
|
|
|
|
// create forward handle
|
|
|
|
|
prop_kind pk = prop_kind::forward;
|
|
|
|
|
fc_fwd::desc fwdDesc = bData != NULL ? fc_fwd::desc(pk, iMD, wMD, bMD, oMD)
|
|
|
|
|
: fc_fwd::desc(pk, iMD, wMD, oMD);
|
|
|
|
|
fc_fwd::desc fwdDesc =
|
|
|
|
|
hasBias ? fc_fwd::desc(pk,
|
|
|
|
|
inVal_->getMD(),
|
|
|
|
|
wgtVal_->getMD(),
|
|
|
|
|
biasVal_->getMD(),
|
|
|
|
|
outVal_->getMD())
|
|
|
|
|
: fc_fwd::desc(
|
|
|
|
|
pk, inVal_->getMD(), wgtVal_->getMD(), outVal_->getMD());
|
|
|
|
|
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
|
|
|
|
|
|
|
|
|
|
if (bData != NULL) {
|
|
|
|
|
biasVal_.reset(new memory(memory::primitive_desc(bMD, engine_), bData));
|
|
|
|
|
if (hasBias) {
|
|
|
|
|
fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
|
|
|
|
|
} else {
|
|
|
|
|
fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
|
|
|
|
@ -197,7 +215,8 @@ void MKLDNNFcLayer::resetBwd() {
|
|
|
|
|
// update data
|
|
|
|
|
inVal_->set_data_handle(iData);
|
|
|
|
|
} else {
|
|
|
|
|
inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
|
|
|
|
|
LOG(FATAL) << "Should not be empty";
|
|
|
|
|
// inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// create memory primitive desc and memory self
|
|
|
|
|