|
|
|
@ -18,6 +18,9 @@ limitations under the License. */
|
|
|
|
|
#include "mkldnn.hpp"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
typedef mkldnn::inner_product_forward fc_fwd;
|
|
|
|
|
typedef mkldnn::inner_product_backward_weights fc_bwdWgt;
|
|
|
|
|
typedef mkldnn::inner_product_backward_data fc_bwdData;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief A subclass of MKLDNNLayer fc layer.
|
|
|
|
@ -32,6 +35,9 @@ protected:
|
|
|
|
|
// if has already init the weight
|
|
|
|
|
bool hasInitedWgt_;
|
|
|
|
|
|
|
|
|
|
// save forward primitive_desc, which can be used backward
|
|
|
|
|
std::shared_ptr<fc_fwd::primitive_desc> fwdPD_;
|
|
|
|
|
|
|
|
|
|
// fc weight and bias
|
|
|
|
|
std::unique_ptr<Weight> weight_;
|
|
|
|
|
std::unique_ptr<Weight> biases_;
|
|
|
|
@ -67,6 +73,59 @@ public:
|
|
|
|
|
void convertWeightsFromPaddle() override;
|
|
|
|
|
|
|
|
|
|
void convertWeightsToPaddle() override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
/**
|
|
|
|
|
* Forward functions: reset buffers(input, output, weight and bias),
|
|
|
|
|
* reset primitive descriptor,
|
|
|
|
|
* reset pipeline.
|
|
|
|
|
*/
|
|
|
|
|
void resetFwdBuffers(MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out);
|
|
|
|
|
void resetInValue(MKLDNNMatrixPtr& in);
|
|
|
|
|
void resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias);
|
|
|
|
|
void resetOutValue(MKLDNNMatrixPtr& out);
|
|
|
|
|
void resetFwdPD(std::shared_ptr<fc_fwd::primitive_desc>& pd,
|
|
|
|
|
MKLDNNMatrixPtr in,
|
|
|
|
|
MKLDNNMatrixPtr wgt,
|
|
|
|
|
MKLDNNMatrixPtr bias,
|
|
|
|
|
MKLDNNMatrixPtr out);
|
|
|
|
|
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
|
|
|
|
|
std::shared_ptr<fc_fwd::primitive_desc>& pd,
|
|
|
|
|
MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out);
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Backward functions: reset buffers(input, output, weight and bias),
|
|
|
|
|
* reset primitive descriptor for backward weight,
|
|
|
|
|
* reset primitive descriptor for backward data,
|
|
|
|
|
* reset pipeline.
|
|
|
|
|
*/
|
|
|
|
|
void resetBwdBuffers(MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out);
|
|
|
|
|
void resetOutGrad(MKLDNNMatrixPtr& out);
|
|
|
|
|
void resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias);
|
|
|
|
|
void resetInGrad(MKLDNNMatrixPtr& in);
|
|
|
|
|
void resetBwdWgtPD(std::shared_ptr<fc_bwdWgt::primitive_desc>& pd,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out);
|
|
|
|
|
void resetBwdDataPD(std::shared_ptr<fc_bwdData::primitive_desc>& pd,
|
|
|
|
|
MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& out);
|
|
|
|
|
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
|
|
|
|
|
std::shared_ptr<fc_bwdWgt::primitive_desc>& bwdWgtPD,
|
|
|
|
|
std::shared_ptr<fc_bwdData::primitive_desc>& bwdDataPD,
|
|
|
|
|
MKLDNNMatrixPtr& in,
|
|
|
|
|
MKLDNNMatrixPtr& wgt,
|
|
|
|
|
MKLDNNMatrixPtr& bias,
|
|
|
|
|
MKLDNNMatrixPtr& out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|