|
|
|
@ -145,6 +145,27 @@ public:
|
|
|
|
|
m_.reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* override the CpuMatrix::resize
|
|
|
|
|
*/
|
|
|
|
|
void resize(size_t newHeight, size_t newWidth) override {
|
|
|
|
|
m_->resize(newHeight, newWidth);
|
|
|
|
|
if (data_ == m_->getData() && elementCnt_ == newHeight * newWidth) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
CpuMatrix::setData(data_);
|
|
|
|
|
height_ = newHeight;
|
|
|
|
|
width_ = newWidth;
|
|
|
|
|
elementCnt_ = newHeight * newWidth;
|
|
|
|
|
stride_ = width_;
|
|
|
|
|
auto pd = mkldnn::memory::primitive_desc(
|
|
|
|
|
mkldnn::memory::desc({(int)newHeight, (int)newWidth},
|
|
|
|
|
getDtype(),
|
|
|
|
|
mkldnn::memory::format::nc),
|
|
|
|
|
getEngine());
|
|
|
|
|
resetMKLDNNMemory(pd, data_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* override Matrix::getData
|
|
|
|
|
* check data before return
|
|
|
|
@ -215,6 +236,17 @@ protected:
|
|
|
|
|
memory::format srcFmt,
|
|
|
|
|
memory::format dstFmt,
|
|
|
|
|
memory::dims dm);
|
|
|
|
|
/**
|
|
|
|
|
* reset this MKLDNN Memory from primitve desc
|
|
|
|
|
*/
|
|
|
|
|
void resetMKLDNNMemory(memory::primitive_desc pd, real* data) {
|
|
|
|
|
mkldnn_primitive_t result;
|
|
|
|
|
mkldnn::error::wrap_c_api(
|
|
|
|
|
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
|
|
|
|
|
"could not create a memory primitive");
|
|
|
|
|
reset(result);
|
|
|
|
|
set_data_handle(data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// save the CpuMatrixPtr in case the buffer released outside
|
|
|
|
|