diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index a5555c4618..ad50c15a7d 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -61,39 +61,20 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { return; } - // TODO(TJ): dst format should get from wgtVal_ - int dstFmt = PARAM_FORMAT_MKLDNN_OI; - int srcFmt = weight_->getParameterPtr()->getHeaderFormat(); - if (srcFmt == dstFmt) { - return; - } - - // The weight_ is transposed from initial paddle weight - MatrixPtr paddleWgt = Matrix::create( - weight_->getW()->getData(), iLayerSize_, oc_, false, false); - - // TODO(TJ): remove this print when do not need differ weights - std::ostringstream ostr; - paddleWgt->print(ostr); - VLOG(MKLDNN_ALL) << "Initial Weight from paddle: " << std::endl << ostr.str(); - - // The mkldnn weight is transposed from initial paddle matrix - MatrixPtr paddleWgtT; - paddleWgt->transpose(paddleWgtT, true); - weight_->getW()->copyFrom(*paddleWgtT); - weight_->getParameterPtr()->setHeaderFormat(dstFmt); + CHECK(wgtVal_) << "should have been initialized"; + bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; + auto targetDim = wgtVal_->getDims(); + auto srcFmt = hasNoSpatial_ ? memory::format::io : memory::format::ihwo; + wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim); hasInitedWgt_ = true; } void MKLDNNFcLayer::convertWeightsToPaddle() { - MatrixPtr dnnWgt = weight_->getW(); - MatrixPtr paddleWgt; - dnnWgt->transpose(paddleWgt, true); - - // copy paddle weight and override on weight_ - MatrixPtr dnnWgtT = Matrix::create( - dnnWgt->getData(), dnnWgt->getWidth(), dnnWgt->getHeight(), false, false); - dnnWgtT->copyFrom(*paddleWgt); + CHECK(wgtVal_) << "should have been initialized"; + bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; + auto targetDim = wgtVal_->getDims(); + auto dstFmt = hasNoSpatial_ ? memory::format::io : memory::format::ihwo; + wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); } void MKLDNNFcLayer::reshape() { diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp index 94df9c1550..32ae3b1bcf 100644 --- a/paddle/math/MKLDNNMatrix.cpp +++ b/paddle/math/MKLDNNMatrix.cpp @@ -56,6 +56,63 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, return create(m, pd); } +void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m, + memory::format srcFmt, + memory::dims targetDim) { + memory::format dstFmt = getFormat(); + if (srcFmt == dstFmt) { + return; + } + CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal"; + real* srcData = getData(); + real* dstData = m->getData(); + reorderOnce(srcData, dstData, srcFmt, dstFmt, targetDim); +} + +void MKLDNNMatrix::reorderDataTo(const MKLDNNMatrixPtr& m, + memory::format dstFmt, + memory::dims targetDim) { + memory::format srcFmt = getFormat(); + if (srcFmt == dstFmt) { + return; + } + CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal"; + real* srcData = getData(); + real* dstData = m->getData(); + reorderOnce(srcData, dstData, srcFmt, dstFmt, targetDim); +} + +void MKLDNNMatrix::reorderOnce(void* srcData, + void* dstData, + memory::format srcFmt, + memory::format dstFmt, + memory::dims dm) { + CHECK(srcData); + CHECK(dstData); + MatrixPtr tmpSrc; + if (dstData == srcData) { + // inplace data + size_t sz = 1; + for (size_t i = 0; i < dm.size(); ++i) { + sz *= dm[i]; + } + tmpSrc = Matrix::create(sz, 1, false, false); + tmpSrc->copyFrom((real*)srcData, sz); + srcData = tmpSrc->getData(); + } + + auto dtype = this->getDtype(); + auto srcMD = memory::desc(dm, dtype, srcFmt); + auto dstMD = memory::desc(dm, dtype, dstFmt); + + auto eg = this->getEngine(); + auto src = memory(memory::primitive_desc(srcMD, eg), srcData); + auto dst = memory(memory::primitive_desc(dstMD, eg), dstData); + + auto r = reorder(src, dst); + stream(stream::kind::eager).submit({r}).wait(); +} + void MKLDNNMatrix::downSpatial() { int fmt = getFormat(); if (!(fmt == memory::format::nchw || fmt == memory::format::oihw)) { diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index 05adc867c2..ea3fd7d461 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -21,9 +21,6 @@ limitations under the License. */ namespace paddle { -static const std::map PARAM_FOARMAT_MAP = - {{mkldnn::memory::format::oi, PARAM_FORMAT_MKLDNN_OI}}; - class MKLDNNMatrix; typedef std::shared_ptr MKLDNNMatrixPtr; @@ -57,6 +54,26 @@ public: mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); public: + /** + * Reorder this MKLDNNMatrix from other format. + * Support inplace reorder + * Pay attention: this function would only reorder the data layout. + * will NOT change this original dim or format info + */ + void reorderDataFrom(const MKLDNNMatrixPtr& m, + memory::format srcFmt, + memory::dims targetDim); + + /** + * Reorder this MKLDNNMatrix to other format. + * Support inplace reorder + * Pay attention: this function would only reorder the data layout. + * will NOT change the dst dim or format info + */ + void reorderDataTo(const MKLDNNMatrixPtr& m, + memory::format dstFmt, + memory::dims targetDim); + /** * Dimensionality reduction. * Change format "nchw --> nc" or "oihw --> oi" if the h and w are both 1 @@ -113,6 +130,16 @@ public: * Get engine. */ mkldnn::engine getEngine() { return getPD().get_engine(); } + +protected: + /** + * Do once reorder supported inplace. + */ + void reorderOnce(void* srcData, + void* dstData, + memory::format srcFmt, + memory::format dstFmt, + memory::dims dm); }; } // namespace paddle