|
|
|
@ -498,15 +498,10 @@ public:
|
|
|
|
|
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
|
|
|
|
|
|
|
|
|
/// todo(tianbing), support SparseMatrixArg for out_mat
|
|
|
|
|
auto out_mat = outputs[0].matrix<Device>();
|
|
|
|
|
LOG(INFO) << "out_mat:";
|
|
|
|
|
out_mat.print(std::cout);
|
|
|
|
|
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg()) {
|
|
|
|
|
LOG(INFO) << "in1_mat:";
|
|
|
|
|
inputs[0].matrix<Device>().print(std::cout);
|
|
|
|
|
LOG(INFO) << "in2_mat:";
|
|
|
|
|
inputs[1].matrix<Device>().print(std::cout);
|
|
|
|
|
/// matrix = matrix * matrix
|
|
|
|
|
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
|
|
|
|
|
!outputs[0].isSparseArg()) {
|
|
|
|
|
MulOp<Device>(out_mat,
|
|
|
|
|
inputs[0].matrix<Device>(),
|
|
|
|
|
inputs[1].matrix<Device>(),
|
|
|
|
@ -515,11 +510,9 @@ public:
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg()) {
|
|
|
|
|
LOG(INFO) << "in1_mat:";
|
|
|
|
|
inputs[0].matrix<Device>().print(std::cout);
|
|
|
|
|
LOG(INFO) << "in2_mat:";
|
|
|
|
|
inputs[1].sparse().SparseMatrix<Device>().print(std::cout);
|
|
|
|
|
/// matrix = matrix * sparse matrix
|
|
|
|
|
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
|
|
|
|
|
!outputs[0].isSparseArg()) {
|
|
|
|
|
MulOp<Device>(out_mat,
|
|
|
|
|
inputs[0].matrix<Device>(),
|
|
|
|
|
inputs[1].sparse().SparseMatrix<Device>(),
|
|
|
|
@ -528,11 +521,9 @@ public:
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg()) {
|
|
|
|
|
LOG(INFO) << "in1_mat:";
|
|
|
|
|
inputs[0].sparse().SparseMatrix<Device>().print(std::cout);
|
|
|
|
|
LOG(INFO) << "in2_mat:";
|
|
|
|
|
inputs[1].matrix<Device>().print(std::cout);
|
|
|
|
|
/// matrix = sparse matrix * matrix
|
|
|
|
|
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
|
|
|
|
|
!outputs[0].isSparseArg()) {
|
|
|
|
|
MulOp<Device>(out_mat,
|
|
|
|
|
inputs[0].sparse().SparseMatrix<Device>(),
|
|
|
|
|
inputs[1].matrix<Device>(),
|
|
|
|
@ -540,6 +531,18 @@ public:
|
|
|
|
|
beta_);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// sparse matrix = matrix * matrix
|
|
|
|
|
auto out_sparse_mat = outputs[0].sparse().SparseMatrix<Device>();
|
|
|
|
|
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
|
|
|
|
|
outputs[0].isSparseArg()) {
|
|
|
|
|
MulOp<Device>(out_sparse_mat,
|
|
|
|
|
inputs[0].matrix<Device>(),
|
|
|
|
|
inputs[1].matrix<Device>(),
|
|
|
|
|
alpha_,
|
|
|
|
|
beta_);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|