@@ -26,22 +26,16 @@ limitations under the License. */
#endif

namespace {
inline void vecAddTo(real* a, const real* b, size_t len) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i] += b[i];
  }
}

inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i] += scaleB * b[i];
    a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
  }
}

inline void colVecAddTo(
    real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i * aWidth] += b[i * bWidth] * c;
    a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
  }
}
} // namespace
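Aside (not part of the diff): the helpers fold the scale factor into an AXPY-style loop, and the new ternary form skips the multiply when the scale is exactly 1.0. colVecAddTo walks one column of a row-major matrix by striding with the row width. A small self-contained illustration of that striding, assuming `real` is `float`:

#include <cstddef>
typedef float real;

// Adds c times the column starting at b (row width bWidth) into the
// column starting at a (row width aWidth); len is the column length.
inline void colVecAddTo(
    real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
  for (size_t i = 0; i < len; ++i) {
    a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
  }
}

// Example: add column 1 of a 2x3 row-major B into column 0 of a 2x2 A.
//   real A[4] = {0, 0, 0, 0};
//   real B[6] = {1, 2, 3, 4, 5, 6};
//   colVecAddTo(&A[0], &B[1], 1.0, /*len=*/2, /*aWidth=*/2, /*bWidth=*/3);
//   A becomes {2, 0, 5, 0}: column 0 of A received column 1 of B.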
@@ -53,15 +47,19 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT) {
  /// todo(tianbing), clean the code
  CHECK(!out.isTransposed()) << "Not supported";
  CHECK_EQ(out.getValueType(), FLOAT_VALUE);
  CHECK(!a.isTransposed() || !b.isTransposed())
      << "Not support both a and b are transpose matrices";
  if (!a.isTransposed() && b.isTransposed()) {
    CHECK(out.getFormat() != SPARSE_CSC)
        << "Not supported CSC format when a is not trans and b is trans";
  }

  size_t height = out.getHeight();
  size_t width = out.getWidth();
  size_t aRow = !a.isTransposed() ? a.getHeight() : a.getWidth();
  size_t aCol = !a.isTransposed() ? a.getWidth() : a.getHeight();
  size_t bRow = !b.isTransposed() ? b.getHeight() : b.getWidth();
  size_t bCol = !b.isTransposed() ? b.getWidth() : b.getHeight();
  /// C = A * B, for matrix format
  CHECK(aCol == bRow && aRow == height && bCol == width);

  if (scaleT == 0) {
    out.zeroMem();
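Aside (not part of the diff): aRow/aCol and bRow/bCol are the effective operand shapes after honoring the transpose flags. For example, if a is stored as a 2x3 buffer but a.isTransposed() is true, it participates as a 3x2 operand, so aRow = a.getWidth() = 3 and aCol = a.getHeight() = 2; the CHECK then enforces the usual conformance for out(height x width) = a(aRow x aCol) * b(bRow x bCol), namely aCol == bRow, aRow == height, and bCol == width.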
@@ -71,93 +69,46 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
  real* C = out.getValue();
  int* rows = out.getRows();
  int* cols = out.getCols();
  size_t height = out.getHeight();
  size_t width = out.getWidth();

  if (!a.isTransposed() && !b.isTransposed()) {
    CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height &&
          b.getWidth() == width);
    size_t m = a.getWidth();
    if (out.getFormat() == SPARSE_CSC) {
      for (size_t i = 0; i < width; i++) {
        size_t start = out.getColStartIdx(i);
        size_t end = out.getColStartIdx(i + 1);
        for (size_t j = start; j < end; j++) {
          real sum = 0;
          size_t rowIdx = rows[j];
          for (size_t k = 0; k < m; k++) {
            sum += A[rowIdx * m + k] * B[k * width + i];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
        }
      }
    } else { /// out.getFormat() == SPARSE_CSR
      for (size_t i = 0; i < height; i++) {
        size_t start = out.getRowStartIdx(i);
        size_t end = out.getRowStartIdx(i + 1);
        for (size_t j = start; j < end; j++) {
          real sum = 0;
          size_t colIdx = cols[j];
          for (size_t k = 0; k < a.getWidth(); k++) {
            sum += A[i * m + k] * B[k * width + colIdx];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
        }
      }
    }
    return;
  }

  if (a.isTransposed() && !b.isTransposed()) {
    CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width &&
          a.getWidth() == height);
    size_t m = a.getHeight();
    if (out.getFormat() == SPARSE_CSC) {
      for (size_t i = 0; i < width; i++) {
        size_t start = out.getColStartIdx(i);
        size_t end = out.getColStartIdx(i + 1);
        for (size_t j = start; j < end; j++) {
          real sum = 0;
          size_t rowIdx = rows[j];
          for (size_t k = 0; k < m; k++) {
            sum += A[k * height + rowIdx] * B[k * width + i];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
        }
      }
    } else { /// out.getFormat() == SPARSE_CSR
      for (size_t i = 0; i < height; i++) {
        int start = out.getRowStartIdx(i);
        int end = out.getRowStartIdx(i + 1);
        for (int j = start; j < end; j++) {
          real sum = 0;
          size_t colIdx = cols[j];
          for (size_t k = 0; k < a.getHeight(); k++) {
            sum += A[k * height + i] * B[k * width + colIdx];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
  /// SPARSE_CSC, {a any, b not trans}
  if (out.getFormat() == SPARSE_CSC) {
    /// b not trans and a any
    CHECK(!b.isTransposed());
    size_t m = !a.isTransposed() ? a.getWidth() : a.getHeight();
    for (size_t i = 0; i < width; i++) {
      size_t start = out.getColStartIdx(i);
      size_t end = out.getColStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t rowIdx = rows[j];
        for (size_t k = 0; k < m; k++) {
          sum +=
              (!a.isTransposed() ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
              B[k * width + i];
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }

  if (!a.isTransposed() && b.isTransposed()) {
    CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height &&
          b.getHeight() == width);
  /// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
  if (out.getFormat() == SPARSE_CSR) {
    /// a and b can not both transpose
    CHECK(!(a.isTransposed() && b.isTransposed()));
    size_t m = a.getWidth();
    if (out.getFormat() == SPARSE_CSR) {
      for (size_t i = 0; i < height; i++) {
        size_t start = out.getRowStartIdx(i);
        size_t end = out.getRowStartIdx(i + 1);
        for (size_t j = start; j < end; j++) {
          real sum = 0;
          size_t colIdx = cols[j];
          for (size_t k = 0; k < m; k++) {
            sum += A[i * m + k] * B[colIdx * m + k];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
    for (size_t i = 0; i < height; i++) {
      size_t start = out.getRowStartIdx(i);
      size_t end = out.getRowStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t colIdx = cols[j];
        for (size_t k = 0; k < m; k++) {
          sum +=
              (!a.isTransposed() ? A[i * m + k] : A[k * height + i]) *
              (!b.isTransposed() ? B[k * width + colIdx] : B[colIdx * m + k]);
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
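Aside (not part of the diff): both the old per-case branches and the new unified CSC/CSR loops compute a dense dot product only for positions that already exist in the sparse output pattern, i.e. C[j] = scaleAB * (row of A · column of B) + scaleT * C[j]. A minimal CSR-style sketch of that idea, using hand-rolled arrays and illustrative names rather than Paddle's CpuSparseMatrix API:

#include <cstddef>
typedef float real;

// values/cols hold the nnz stored entries of the sparse output C (height x width);
// rowStart has height + 1 offsets. A (height x k) and B (k x width) are dense,
// row-major. Only the stored positions of C are updated.
void sparseOutputMul(real* values, const int* rowStart, const int* cols,
                     size_t height, size_t k, size_t width,
                     const real* A, const real* B, real scaleAB, real scaleT) {
  for (size_t i = 0; i < height; ++i) {
    for (int j = rowStart[i]; j < rowStart[i + 1]; ++j) {
      real sum = 0;
      const size_t col = cols[j];
      for (size_t p = 0; p < k; ++p) {
        sum += A[i * k + p] * B[p * width + col];  // dense dot product
      }
      values[j] = scaleAB * sum + scaleT * values[j];
    }
  }
}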
@@ -330,11 +281,11 @@ public:
    CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);

    auto out_mat = outputs[0].matrix<Device>();
    auto outMat = outputs[0].matrix<Device>();
    /// matrix = matrix * matrix
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      MulOp<Device>(out_mat,
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
@@ -345,7 +296,7 @@ public:
    /// matrix = matrix * sparse matrix
    if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      MulOp<Device>(out_mat,
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].sparse().SparseMatrix<Device>(),
                    alpha_,
@@ -356,7 +307,7 @@ public:
    /// matrix = sparse matrix * matrix
    if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      MulOp<Device>(out_mat,
      MulOp<Device>(outMat,
                    inputs[0].sparse().SparseMatrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
@@ -365,10 +316,10 @@ public:
    }

    /// sparse matrix = matrix * matrix
    auto out_sparse_mat = outputs[0].sparse().SparseMatrix<Device>();
    auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        outputs[0].isSparseArg()) {
      MulOp<Device>(out_sparse_mat,
      MulOp<Device>(outSparseMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
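Aside (not part of the diff; the calls above are cut off at `alpha_,` by the hunk boundaries): the -330/-345/-356/-365 hunks are the same mechanical rename (out_mat -> outMat, out_sparse_mat -> outSparseMat) applied to each branch of the sparsity dispatch. The runtime isSparseArg() checks pick which accessor to call, and the static type of that accessor's result then selects the matching MulOp overload. A stripped-down sketch of that overload-selection pattern, with hypothetical Dense/Sparse placeholder types rather than Paddle's matrix classes:

#include <iostream>

struct Dense {};
struct Sparse {};

// One overload per supported combination, mirroring the four branches above.
void mul(Dense&, const Dense&, const Dense&) { std::cout << "dense = dense * dense\n"; }
void mul(Dense&, const Dense&, const Sparse&) { std::cout << "dense = dense * sparse\n"; }
void mul(Dense&, const Sparse&, const Dense&) { std::cout << "dense = sparse * dense\n"; }
void mul(Sparse&, const Dense&, const Dense&) { std::cout << "sparse = dense * dense\n"; }

int main() {
  Dense out, a;
  Sparse b;
  mul(out, a, b);  // resolves to the "dense = dense * sparse" overload
}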