clean the code a little bit.

avx_docs
xutianbing 8 years ago
parent 4751cc8f7e
commit 171eaff216

@ -38,13 +38,6 @@ inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
}
}
inline void colVecAddTo(
real* a, const real* b, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) {
a[i * aWidth] += b[i * bWidth];
}
}
inline void colVecAddTo(
real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) {
@ -336,140 +329,59 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuSparseMatrix& b,
real scaleAB,
real scaleT) {
/// todo(tianbing), clean the code
CHECK(!out.trans_) << "Not supported";
CHECK(!a.isTransposed()) << "Not supported";
CHECK(scaleT == 0 || scaleT == 1);
CHECK_EQ(scaleAB, static_cast<real>(1.0));
if (!b.isTransposed()) { /// b is not Transpose
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == out.getWidth());
} else {
CHECK(b.getHeight() == out.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == a.getWidth());
}
if (scaleT == 0) {
out.zeroMem();
}
real* A = const_cast<real*>(a.getData());
real* B = const_cast<real*>(b.getValue());
real* C = out.getData();
int* rows = b.getRows();
int* cols = b.getCols();
if (scaleT == 0) {
out.zeroMem();
}
/// todo(tianbing), clean the code
/// b.getFormat() == SPARSE_CSC
if (b.getFormat() == SPARSE_CSC) {
if (!b.isTransposed()) {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), out.width_);
if (b.getValueType() == NO_VALUE) {
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(
C + j, A + rows[i], out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(C + j,
A + rows[i],
B[i],
out.height_,
out.width_,
a.getWidth());
}
}
}
} else /*if (b.isTransposed())*/ {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), out.width_);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), m);
if (b.getValueType() == NO_VALUE) {
for (size_t i = 0; i < b.getWidth(); ++i) {
int start = b.getColStartIdx(i);
int end = b.getColStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(
C + rows[j], A + i, out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < b.getWidth(); ++i) {
int start = b.getColStartIdx(i);
int end = b.getColStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(C + rows[j],
A + i,
B[j],
out.height_,
out.width_,
a.getWidth());
}
}
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(!b.isTransposed() ? C + j : C + rows[i],
!b.isTransposed() ? A + rows[i] : A + j,
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(),
out.getWidth(),
a.getWidth());
}
}
} else {
if (!b.isTransposed()) {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), out.width_);
if (b.getValueType() == NO_VALUE) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(
C + cols[i], A + j, out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(C + cols[i],
A + j,
B[i],
out.height_,
out.width_,
a.getWidth());
}
}
}
} else /*if (b.isTransposed())*/ {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), out.width_);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), m);
if (b.getValueType() == NO_VALUE) {
for (size_t i = 0; i < b.getHeight(); ++i) {
int start = b.getRowStartIdx(i);
int end = b.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(
C + i, A + cols[j], out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < b.getHeight(); ++i) {
int start = b.getRowStartIdx(i);
int end = b.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(C + i,
A + cols[j],
B[j],
out.height_,
out.width_,
a.getWidth());
}
}
return;
}
/// b.getFormat() == SPARSE_CSR
if (b.getFormat() == SPARSE_CSR) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(!b.isTransposed() ? C + cols[i] : C + j,
!b.isTransposed() ? A + j : A + cols[i],
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(),
out.getWidth(),
a.getWidth());
}
}
return;
}
}

File diff suppressed because it is too large Load Diff

@ -76,12 +76,12 @@ void testDDDMatrix(bool transa, bool transb, int dimM, int dimN, int dimK) {
TEST(Matrix, DDDMul) {
LOG(INFO) << "test for dense = dense * dense matrix";
for (auto transa : {false, true}) {
for (auto transb : {false, true}) {
for (auto dimM : {1, 10, 100}) {
for (auto dimN : {1, 10}) {
for (auto dimK : {8}) {
if (true == transa && true == transb) {
for (const auto transa : {false, true}) {
for (const auto transb : {false, true}) {
for (const auto dimM : {1, 10, 100}) {
for (const auto dimN : {1, 10}) {
for (const auto dimK : {8}) {
if (transa && transb) {
continue;
}
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
@ -89,7 +89,6 @@ TEST(Matrix, DDDMul) {
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK;
testDDDMatrix(transa, transb, dimM, dimN, dimK);
}
}

Loading…
Cancel
Save