|
|
|
@ -30,11 +30,10 @@ typedef std::shared_ptr<MKLDNNMatrix> MKLDNNMatrixPtr;
|
|
|
|
|
*/
|
|
|
|
|
class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
|
|
|
|
|
public:
|
|
|
|
|
MKLDNNMatrix(real* data,
|
|
|
|
|
size_t height,
|
|
|
|
|
size_t width,
|
|
|
|
|
mkldnn::memory::primitive_desc pd)
|
|
|
|
|
: CpuMatrix(data, height, width, false), mkldnn::memory(pd, data) {}
|
|
|
|
|
MKLDNNMatrix(CpuMatrixPtr m, mkldnn::memory::primitive_desc pd)
|
|
|
|
|
: CpuMatrix(m->getData(), m->getHeight(), m->getWidth(), false),
|
|
|
|
|
mkldnn::memory(pd, m->getData()),
|
|
|
|
|
m_(m) {}
|
|
|
|
|
|
|
|
|
|
~MKLDNNMatrix() {}
|
|
|
|
|
|
|
|
|
@ -81,11 +80,29 @@ public:
|
|
|
|
|
void downSpatial();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Update the memory data handle.
|
|
|
|
|
* set the memory data handle.
|
|
|
|
|
* Caution: This will not check the buffer size of the data,
|
|
|
|
|
* it should be coverd by user.
|
|
|
|
|
*/
|
|
|
|
|
void updateData(void* data) { set_data_handle(data); }
|
|
|
|
|
void setData(real* data) {
|
|
|
|
|
set_data_handle(data);
|
|
|
|
|
CpuMatrix::setData(data);
|
|
|
|
|
m_.reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* override Matrix::getData
|
|
|
|
|
* check data before return
|
|
|
|
|
*/
|
|
|
|
|
real* getData() override {
|
|
|
|
|
CHECK_EQ((void*)data_, get_data_handle());
|
|
|
|
|
return data_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const real* getData() const override {
|
|
|
|
|
CHECK_EQ((void*)data_, get_data_handle());
|
|
|
|
|
return data_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Get primitive descriptor.
|
|
|
|
@ -143,6 +160,10 @@ protected:
|
|
|
|
|
memory::format srcFmt,
|
|
|
|
|
memory::format dstFmt,
|
|
|
|
|
memory::dims dm);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// save the CpuMatrixPtr in case the buffer released outside
|
|
|
|
|
CpuMatrixPtr m_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|