|
|
|
@ -38,13 +38,6 @@ inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Strided accumulate: performs a[i * aWidth] += b[i * bWidth] for every
 * i in [0, len).
 *
 * NOTE(review): callers in this file pass a matrix width as each stride, so
 * this presumably walks one column of a row-major matrix -- confirm against
 * the matrix layout.
 *
 * @param a       destination base pointer; element i is at a[i * aWidth].
 * @param b       source base pointer; element i is at b[i * bWidth].
 * @param len     number of elements to accumulate.
 * @param aWidth  stride (in elements) between consecutive entries of a.
 * @param bWidth  stride (in elements) between consecutive entries of b.
 */
inline void colVecAddTo(
    real* a, const real* b, size_t len, size_t aWidth, size_t bWidth) {
  // The counter has the same type as len: a narrower unsigned int counter
  // would wrap around before reaching a len above UINT_MAX and never
  // terminate.
  for (size_t i = 0; i < len; ++i) {
    a[i * aWidth] += b[i * bWidth];
  }
}
|
|
|
|
|
|
|
|
|
|
inline void colVecAddTo(
|
|
|
|
|
real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
|
|
|
|
|
for (unsigned int i = 0; i < len; ++i) {
|
|
|
|
@ -336,140 +329,59 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
|
|
|
|
|
const CpuSparseMatrix& b,
|
|
|
|
|
real scaleAB,
|
|
|
|
|
real scaleT) {
|
|
|
|
|
/// todo(tianbing), clean the code
|
|
|
|
|
CHECK(!out.trans_) << "Not supported";
|
|
|
|
|
CHECK(!a.isTransposed()) << "Not supported";
|
|
|
|
|
CHECK(scaleT == 0 || scaleT == 1);
|
|
|
|
|
CHECK_EQ(scaleAB, static_cast<real>(1.0));
|
|
|
|
|
if (!b.isTransposed()) { /// b is not Transpose
|
|
|
|
|
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
|
|
|
|
|
b.getWidth() == out.getWidth());
|
|
|
|
|
} else {
|
|
|
|
|
CHECK(b.getHeight() == out.getWidth() && a.getHeight() == out.getHeight() &&
|
|
|
|
|
b.getWidth() == a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (scaleT == 0) {
|
|
|
|
|
out.zeroMem();
|
|
|
|
|
}
|
|
|
|
|
real* A = const_cast<real*>(a.getData());
|
|
|
|
|
real* B = const_cast<real*>(b.getValue());
|
|
|
|
|
real* C = out.getData();
|
|
|
|
|
int* rows = b.getRows();
|
|
|
|
|
int* cols = b.getCols();
|
|
|
|
|
|
|
|
|
|
if (scaleT == 0) {
|
|
|
|
|
out.zeroMem();
|
|
|
|
|
}
|
|
|
|
|
/// todo(tianbing), clean the code
|
|
|
|
|
/// b.getFormat() == SPARSE_CSC
|
|
|
|
|
if (b.getFormat() == SPARSE_CSC) {
|
|
|
|
|
if (!b.isTransposed()) {
|
|
|
|
|
size_t m = a.getWidth();
|
|
|
|
|
CHECK_EQ(b.getHeight(), m);
|
|
|
|
|
CHECK_EQ(a.getHeight(), out.height_);
|
|
|
|
|
CHECK_EQ(b.getWidth(), out.width_);
|
|
|
|
|
|
|
|
|
|
if (b.getValueType() == NO_VALUE) {
|
|
|
|
|
for (size_t j = 0; j < b.getWidth(); ++j) {
|
|
|
|
|
int start = b.getColStartIdx(j);
|
|
|
|
|
int end = b.getColStartIdx(j + 1);
|
|
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
|
colVecAddTo(
|
|
|
|
|
C + j, A + rows[i], out.height_, out.width_, a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (b.getValueType() == FLOAT_VALUE) {
|
|
|
|
|
for (size_t j = 0; j < b.getWidth(); ++j) {
|
|
|
|
|
int start = b.getColStartIdx(j);
|
|
|
|
|
int end = b.getColStartIdx(j + 1);
|
|
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
|
colVecAddTo(C + j,
|
|
|
|
|
A + rows[i],
|
|
|
|
|
B[i],
|
|
|
|
|
out.height_,
|
|
|
|
|
out.width_,
|
|
|
|
|
a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else /*if (b.isTransposed())*/ {
|
|
|
|
|
size_t m = a.getWidth();
|
|
|
|
|
CHECK_EQ(b.getHeight(), out.width_);
|
|
|
|
|
CHECK_EQ(a.getHeight(), out.height_);
|
|
|
|
|
CHECK_EQ(b.getWidth(), m);
|
|
|
|
|
if (b.getValueType() == NO_VALUE) {
|
|
|
|
|
for (size_t i = 0; i < b.getWidth(); ++i) {
|
|
|
|
|
int start = b.getColStartIdx(i);
|
|
|
|
|
int end = b.getColStartIdx(i + 1);
|
|
|
|
|
for (int j = start; j < end; ++j) {
|
|
|
|
|
colVecAddTo(
|
|
|
|
|
C + rows[j], A + i, out.height_, out.width_, a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (b.getValueType() == FLOAT_VALUE) {
|
|
|
|
|
for (size_t i = 0; i < b.getWidth(); ++i) {
|
|
|
|
|
int start = b.getColStartIdx(i);
|
|
|
|
|
int end = b.getColStartIdx(i + 1);
|
|
|
|
|
for (int j = start; j < end; ++j) {
|
|
|
|
|
colVecAddTo(C + rows[j],
|
|
|
|
|
A + i,
|
|
|
|
|
B[j],
|
|
|
|
|
out.height_,
|
|
|
|
|
out.width_,
|
|
|
|
|
a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 0; j < b.getWidth(); ++j) {
|
|
|
|
|
int start = b.getColStartIdx(j);
|
|
|
|
|
int end = b.getColStartIdx(j + 1);
|
|
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
|
colVecAddTo(!b.isTransposed() ? C + j : C + rows[i],
|
|
|
|
|
!b.isTransposed() ? A + rows[i] : A + j,
|
|
|
|
|
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
|
|
|
|
|
out.getHeight(),
|
|
|
|
|
out.getWidth(),
|
|
|
|
|
a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (!b.isTransposed()) {
|
|
|
|
|
size_t m = a.getWidth();
|
|
|
|
|
CHECK_EQ(b.getHeight(), m);
|
|
|
|
|
CHECK_EQ(a.getHeight(), out.height_);
|
|
|
|
|
CHECK_EQ(b.getWidth(), out.width_);
|
|
|
|
|
|
|
|
|
|
if (b.getValueType() == NO_VALUE) {
|
|
|
|
|
for (size_t j = 0; j < b.getHeight(); ++j) {
|
|
|
|
|
int start = b.getRowStartIdx(j);
|
|
|
|
|
int end = b.getRowStartIdx(j + 1);
|
|
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
|
colVecAddTo(
|
|
|
|
|
C + cols[i], A + j, out.height_, out.width_, a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (b.getValueType() == FLOAT_VALUE) {
|
|
|
|
|
for (size_t j = 0; j < b.getHeight(); ++j) {
|
|
|
|
|
int start = b.getRowStartIdx(j);
|
|
|
|
|
int end = b.getRowStartIdx(j + 1);
|
|
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
|
colVecAddTo(C + cols[i],
|
|
|
|
|
A + j,
|
|
|
|
|
B[i],
|
|
|
|
|
out.height_,
|
|
|
|
|
out.width_,
|
|
|
|
|
a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else /*if (b.isTransposed())*/ {
|
|
|
|
|
size_t m = a.getWidth();
|
|
|
|
|
CHECK_EQ(b.getHeight(), out.width_);
|
|
|
|
|
CHECK_EQ(a.getHeight(), out.height_);
|
|
|
|
|
CHECK_EQ(b.getWidth(), m);
|
|
|
|
|
if (b.getValueType() == NO_VALUE) {
|
|
|
|
|
for (size_t i = 0; i < b.getHeight(); ++i) {
|
|
|
|
|
int start = b.getRowStartIdx(i);
|
|
|
|
|
int end = b.getRowStartIdx(i + 1);
|
|
|
|
|
for (int j = start; j < end; ++j) {
|
|
|
|
|
colVecAddTo(
|
|
|
|
|
C + i, A + cols[j], out.height_, out.width_, a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (b.getValueType() == FLOAT_VALUE) {
|
|
|
|
|
for (size_t i = 0; i < b.getHeight(); ++i) {
|
|
|
|
|
int start = b.getRowStartIdx(i);
|
|
|
|
|
int end = b.getRowStartIdx(i + 1);
|
|
|
|
|
for (int j = start; j < end; ++j) {
|
|
|
|
|
colVecAddTo(C + i,
|
|
|
|
|
A + cols[j],
|
|
|
|
|
B[j],
|
|
|
|
|
out.height_,
|
|
|
|
|
out.width_,
|
|
|
|
|
a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// b.getFormat() == SPARSE_CSR
|
|
|
|
|
if (b.getFormat() == SPARSE_CSR) {
|
|
|
|
|
for (size_t j = 0; j < b.getHeight(); ++j) {
|
|
|
|
|
int start = b.getRowStartIdx(j);
|
|
|
|
|
int end = b.getRowStartIdx(j + 1);
|
|
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
|
colVecAddTo(!b.isTransposed() ? C + cols[i] : C + j,
|
|
|
|
|
!b.isTransposed() ? A + j : A + cols[i],
|
|
|
|
|
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
|
|
|
|
|
out.getHeight(),
|
|
|
|
|
out.getWidth(),
|
|
|
|
|
a.getWidth());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|