Pass unit test for SparseCpuMatrix::mul(CpuMatrix, CpuMatrix),

SparseGpuMatrix::mul(GpuMatrix, GpuMatrix),
CpuMatrix::mul(CpuSparseMatrix, CpuMatrix),
and GpuMatrix::mul(GpuSparseMatrix, GpuMatrix)
avx_docs
xutianbing 8 years ago
parent 1ca2846ef6
commit 4751cc8f7e

@ -498,15 +498,10 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
/// todo(tianbing), support SparseMatrixArg for out_mat
auto out_mat = outputs[0].matrix<Device>();
LOG(INFO) << "out_mat:";
out_mat.print(std::cout);
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg()) {
LOG(INFO) << "in1_mat:";
inputs[0].matrix<Device>().print(std::cout);
LOG(INFO) << "in2_mat:";
inputs[1].matrix<Device>().print(std::cout);
/// matrix = matrix * matrix
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
@ -515,11 +510,9 @@ public:
return;
}
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg()) {
LOG(INFO) << "in1_mat:";
inputs[0].matrix<Device>().print(std::cout);
LOG(INFO) << "in2_mat:";
inputs[1].sparse().SparseMatrix<Device>().print(std::cout);
/// matrix = matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(),
@ -528,11 +521,9 @@ public:
return;
}
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg()) {
LOG(INFO) << "in1_mat:";
inputs[0].sparse().SparseMatrix<Device>().print(std::cout);
LOG(INFO) << "in2_mat:";
inputs[1].matrix<Device>().print(std::cout);
/// matrix = sparse matrix * matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(),
@ -540,6 +531,18 @@ public:
beta_);
return;
}
/// sparse matrix = matrix * matrix
auto out_sparse_mat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) {
MulOp<Device>(out_sparse_mat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
beta_);
return;
}
}
private:

@ -176,7 +176,36 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
const GpuMatrix& b,
real scale_ab,
real scale_t) {
/// todo(tianbing), implement it
/// todo(tianbing), clean the code
CHECK(a.useGpu_ && b.useGpu_) << "type not match";
CHECK(!out.trans_) << "trans not supported";
real* a_data = const_cast<real*>(a.getData());
real* b_data = const_cast<real*>(b.getData());
hl_sparse_matrix_s out_data = out.sMatrix_.get();
hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
if (!a.trans_ && !b.trans_) {
CHECK(out.height_ == a.getHeight());
CHECK(out.width_ == b.getWidth());
CHECK(a.getWidth() == b.getHeight());
} else if (a.trans_ && !b.trans_) {
CHECK(out.height_ == a.getWidth());
CHECK(out.width_ == b.getWidth());
CHECK(a.getHeight() == b.getHeight());
} else if (!a.trans_ && b.trans_) {
CHECK(out.height_ == a.getHeight());
CHECK(out.width_ == b.getHeight());
CHECK(a.getWidth() == b.getWidth());
} else {
LOG(INFO) << "Not support";
}
int dim_m = out.height_;
int dim_n = out.width_;
int dim_k = !b.trans_ ? b.getHeight() : b.getWidth();
hl_sparse_matrix_mul(
a_data, a_trans, b_data, b_trans, out_data,
dim_m, dim_n, dim_k, scale_ab, scale_t);
}
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -177,7 +177,6 @@ GpuSparseMatrix::GpuSparseMatrix(real* value,
hl_sparse_matrix_s_ptr tmp2(tmp, hl_destruct_sparse_matrix);
sMatrix_ = tmp2;
}
LOG(INFO) << "weight to matrix ";
}
}

@ -30,6 +30,17 @@ void checkMatrixEqual(const MatrixPtr& a, const MatrixPtr& b) {
}
}
/// Asserts that two CPU sparse matrices are equal in shape, transposed flag,
/// storage format and stored-element count, and that their value arrays
/// match element by element (compared with ASSERT_FLOAT_EQ).
///
/// Only the value arrays are walked here; the row/column index arrays are
/// not compared by this helper.
///
/// @param a first sparse matrix
/// @param b second sparse matrix
void checkSMatrixEqual(const CpuSparseMatrix& a, const CpuSparseMatrix& b) {
  // Structural properties first: a fatal assertion aborts this helper before
  // any value is read.
  ASSERT_EQ(a.getWidth(), b.getWidth());
  ASSERT_EQ(a.getHeight(), b.getHeight());
  ASSERT_EQ(a.isTransposed(), b.isTransposed());
  ASSERT_EQ(a.getFormat(), b.getFormat());
  ASSERT_EQ(a.getElementCnt(), b.getElementCnt());

  // Both matrices hold the same number of stored values at this point.
  const auto storedCount = a.getElementCnt();
  size_t r = 0;
  while (r < storedCount) {
    ASSERT_FLOAT_EQ(a.getValue()[r], b.getValue()[r]);
    ++r;
  }
}
void checkSMatrixEqual(const CpuSparseMatrixPtr& a,
const CpuSparseMatrixPtr& b) {
ASSERT_EQ(a->getWidth(), b->getWidth());
@ -73,6 +84,36 @@ void checkSMatrixEqual2(const CpuSparseMatrixPtr& a,
}
}
/// Asserts that every entry stored in sparse matrix `a` is matched by the
/// element at the same (row, column) position of dense matrix `b`.
///
/// Width, height and the transposed flag must agree first. If `a` does not
/// carry FLOAT_VALUE data, each stored entry is expected to read 1.0 in `b`.
/// Only positions stored in `a` are visited; the remaining elements of `b`
/// are not inspected by this helper.
///
/// @param a sparse matrix (SPARSE_CSC is handled column by column; any other
///          format is walked row by row -- presumably CSR, confirm with the
///          format enum)
/// @param b dense matrix to compare against
void checkSMatrixEqual2Dense(const CpuSparseMatrix& a, const CpuMatrix& b) {
  ASSERT_EQ(a.getWidth(), b.getWidth());
  ASSERT_EQ(a.getHeight(), b.getHeight());
  ASSERT_EQ(a.isTransposed(), b.isTransposed());

  // Loop-invariant: decides whether stored values or an implicit 1.0 are
  // compared against the dense matrix.
  const bool hasFloatValues = (a.getValueType() == FLOAT_VALUE);

  if (a.getFormat() != SPARSE_CSC) {
    // Row-indexed layout: i is the row, cols[j] the column of stored entry j.
    const int* cols = a.getCols();
    for (size_t i = 0; i < a.getHeight(); i++) {
      const auto rowEnd = a.getRowStartIdx(i + 1);
      for (size_t j = a.getRowStartIdx(i); j < rowEnd; j++) {
        if (hasFloatValues) {
          ASSERT_FLOAT_EQ(a.getValue()[j], b.getElement(i, cols[j]));
        } else {
          ASSERT_FLOAT_EQ(1.0, b.getElement(i, cols[j]));
        }
      }
    }
    return;
  }

  // Column-indexed layout: i is the column, rows[j] the row of stored entry j.
  const int* rows = a.getRows();
  for (size_t i = 0; i < a.getWidth(); i++) {
    const auto colEnd = a.getColStartIdx(i + 1);
    for (size_t j = a.getColStartIdx(i); j < colEnd; j++) {
      if (hasFloatValues) {
        ASSERT_FLOAT_EQ(a.getValue()[j], b.getElement(rows[j], i));
      } else {
        ASSERT_FLOAT_EQ(1.0, b.getElement(rows[j], i));
      }
    }
  }
}
void checkSMatrixEqual2Dense(const CpuSparseMatrixPtr& a,
const CpuMatrixPtr& b) {
ASSERT_EQ(a->getWidth(), b->getWidth());

Loading…
Cancel
Save