|
|
|
@ -56,6 +56,63 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
|
|
|
|
|
return create(m, pd);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m,
|
|
|
|
|
memory::format srcFmt,
|
|
|
|
|
memory::dims targetDim) {
|
|
|
|
|
memory::format dstFmt = getFormat();
|
|
|
|
|
if (srcFmt == dstFmt) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal";
|
|
|
|
|
real* srcData = getData();
|
|
|
|
|
real* dstData = m->getData();
|
|
|
|
|
reorderOnce(srcData, dstData, srcFmt, dstFmt, targetDim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNMatrix::reorderDataTo(const MKLDNNMatrixPtr& m,
|
|
|
|
|
memory::format dstFmt,
|
|
|
|
|
memory::dims targetDim) {
|
|
|
|
|
memory::format srcFmt = getFormat();
|
|
|
|
|
if (srcFmt == dstFmt) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal";
|
|
|
|
|
real* srcData = getData();
|
|
|
|
|
real* dstData = m->getData();
|
|
|
|
|
reorderOnce(srcData, dstData, srcFmt, dstFmt, targetDim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNMatrix::reorderOnce(void* srcData,
|
|
|
|
|
void* dstData,
|
|
|
|
|
memory::format srcFmt,
|
|
|
|
|
memory::format dstFmt,
|
|
|
|
|
memory::dims dm) {
|
|
|
|
|
CHECK(srcData);
|
|
|
|
|
CHECK(dstData);
|
|
|
|
|
MatrixPtr tmpSrc;
|
|
|
|
|
if (dstData == srcData) {
|
|
|
|
|
// inplace data
|
|
|
|
|
size_t sz = 1;
|
|
|
|
|
for (size_t i = 0; i < dm.size(); ++i) {
|
|
|
|
|
sz *= dm[i];
|
|
|
|
|
}
|
|
|
|
|
tmpSrc = Matrix::create(sz, 1, false, false);
|
|
|
|
|
tmpSrc->copyFrom((real*)srcData, sz);
|
|
|
|
|
srcData = tmpSrc->getData();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dtype = this->getDtype();
|
|
|
|
|
auto srcMD = memory::desc(dm, dtype, srcFmt);
|
|
|
|
|
auto dstMD = memory::desc(dm, dtype, dstFmt);
|
|
|
|
|
|
|
|
|
|
auto eg = this->getEngine();
|
|
|
|
|
auto src = memory(memory::primitive_desc(srcMD, eg), srcData);
|
|
|
|
|
auto dst = memory(memory::primitive_desc(dstMD, eg), dstData);
|
|
|
|
|
|
|
|
|
|
auto r = reorder(src, dst);
|
|
|
|
|
stream(stream::kind::eager).submit({r}).wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNMatrix::downSpatial() {
|
|
|
|
|
int fmt = getFormat();
|
|
|
|
|
if (!(fmt == memory::format::nchw || fmt == memory::format::oihw)) {
|
|
|
|
|