|
|
|
@ -27,38 +27,22 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
|
|
|
|
|
const GpuMatrix& a,
|
|
|
|
|
const GpuMatrix& b,
|
|
|
|
|
real scaleAB,
|
|
|
|
|
real scaleT) {
|
|
|
|
|
CHECK(!out.isTransposed()) << "Transpose not supported for out matrix";
|
|
|
|
|
if (!a.isTransposed() && !b.isTransposed()) {
|
|
|
|
|
/// a : M * K, b: K * N
|
|
|
|
|
CHECK(out.getWidth() == b.getWidth() &&
|
|
|
|
|
out.getHeight() == a.getHeight() &&
|
|
|
|
|
a.getWidth() == b.getHeight());
|
|
|
|
|
} else if (a.isTransposed() && !b.isTransposed()) {
|
|
|
|
|
/// a : K * M, b : K * N
|
|
|
|
|
CHECK(out.getWidth() == b.getWidth() &&
|
|
|
|
|
out.getHeight() == a.getWidth() &&
|
|
|
|
|
a.getHeight() == b.getHeight());
|
|
|
|
|
} else if (!a.isTransposed() && b.isTransposed()) {
|
|
|
|
|
/// a: M * K, b : N * K
|
|
|
|
|
CHECK(out.getWidth() == b.getHeight() &&
|
|
|
|
|
out.getHeight() == a.getHeight() &&
|
|
|
|
|
a.getWidth() == b.getWidth());
|
|
|
|
|
} else {
|
|
|
|
|
LOG(FATAL) << "Not support for both a and b are Transposed Matrices";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
real scaleT,
|
|
|
|
|
bool aTrans,
|
|
|
|
|
bool bTrans,
|
|
|
|
|
bool cTrans) {
|
|
|
|
|
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
|
|
|
|
|
real* aData = const_cast<real*>(a.getData());
|
|
|
|
|
real* bData = const_cast<real*>(b.getData());
|
|
|
|
|
real* outData = const_cast<real*>(out.getData());
|
|
|
|
|
hl_matrix_mul(aData,
|
|
|
|
|
!a.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
|
|
|
|
|
!aTrans ? HPPL_OP_N : HPPL_OP_T,
|
|
|
|
|
bData,
|
|
|
|
|
!b.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
|
|
|
|
|
!bTrans ? HPPL_OP_N : HPPL_OP_T,
|
|
|
|
|
outData,
|
|
|
|
|
out.getHeight(),
|
|
|
|
|
out.getWidth(),
|
|
|
|
|
!a.isTransposed() ? a.getWidth() : a.getHeight(),
|
|
|
|
|
!aTrans ? a.getWidth() : a.getHeight(),
|
|
|
|
|
scaleAB,
|
|
|
|
|
scaleT,
|
|
|
|
|
a.getStride(),
|
|
|
|
@ -75,27 +59,19 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
|
|
|
|
|
const GpuSparseMatrix& a,
|
|
|
|
|
const GpuMatrix& b,
|
|
|
|
|
real scaleAB,
|
|
|
|
|
real scaleT) {
|
|
|
|
|
real scaleT,
|
|
|
|
|
bool aTrans,
|
|
|
|
|
bool bTrans,
|
|
|
|
|
bool cTrans) {
|
|
|
|
|
CHECK(out.isContiguous());
|
|
|
|
|
CHECK(b.isContiguous());
|
|
|
|
|
CHECK(b.useGpu_) << "Matrix type are not equal";
|
|
|
|
|
CHECK(!out.isTransposed() && !b.isTransposed()) << "not supported";
|
|
|
|
|
if (!a.isTransposed()) {
|
|
|
|
|
/// a: M * K, b: K * N
|
|
|
|
|
CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getHeight()
|
|
|
|
|
&& a.getWidth() == b.getHeight()) << "Matrix dimensions are not equal";
|
|
|
|
|
} else {
|
|
|
|
|
/// a: K * M, transpose, b: K * N
|
|
|
|
|
CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getWidth()
|
|
|
|
|
&& a.getHeight() == b.getHeight()) << "Matrix dimensions are not equal";
|
|
|
|
|
}
|
|
|
|
|
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
|
|
|
|
|
|
|
|
|
|
hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
|
|
|
|
|
hl_sparse_matrix_s aData = a.sMatrix_.get();
|
|
|
|
|
real* bData = const_cast<real*>(b.getData());
|
|
|
|
|
real* outData = const_cast<real*>(out.getData());
|
|
|
|
|
hl_matrix_csr_mul_dense(aData,
|
|
|
|
|
aTrans,
|
|
|
|
|
aTrans ? HPPL_OP_T : HPPL_OP_N,
|
|
|
|
|
bData,
|
|
|
|
|
HPPL_OP_N,
|
|
|
|
|
outData,
|
|
|
|
@ -115,25 +91,14 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
|
|
|
|
|
const GpuMatrix& a,
|
|
|
|
|
const GpuSparseMatrix& b,
|
|
|
|
|
real scaleAB,
|
|
|
|
|
real scaleT) {
|
|
|
|
|
real scaleT,
|
|
|
|
|
bool aTrans,
|
|
|
|
|
bool bTrans,
|
|
|
|
|
bool cTrans) {
|
|
|
|
|
CHECK(out.isContiguous());
|
|
|
|
|
CHECK(a.isContiguous());
|
|
|
|
|
CHECK(a.useGpu_) << "Matrix type are not equal";
|
|
|
|
|
if (!b.isTransposed()) {
|
|
|
|
|
/// a : M * K, b : K * N
|
|
|
|
|
CHECK(out.getWidth() == b.getWidth() &&
|
|
|
|
|
out.getHeight() == a.getHeight() &&
|
|
|
|
|
a.getWidth() == b.getHeight())
|
|
|
|
|
<< "Matrix dimensions are not equal";
|
|
|
|
|
} else {
|
|
|
|
|
/// a : M * K, b : N * K, transpose
|
|
|
|
|
CHECK(out.getWidth() == b.getHeight() &&
|
|
|
|
|
out.getHeight() == a.getHeight() &&
|
|
|
|
|
a.getWidth() == b.getWidth())
|
|
|
|
|
<< "Matrix dimensions are not equal";
|
|
|
|
|
}
|
|
|
|
|
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
|
|
|
|
|
|
|
|
|
|
hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
|
|
|
|
|
hl_sparse_matrix_s bData = b.sMatrix_.get();
|
|
|
|
|
real* aData = const_cast<real*>(a.getData());
|
|
|
|
|
real* outData = const_cast<real*>(out.getData());
|
|
|
|
@ -142,7 +107,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
|
|
|
|
|
hl_matrix_dense_mul_csc(aData,
|
|
|
|
|
HPPL_OP_N,
|
|
|
|
|
bData,
|
|
|
|
|
bTrans,
|
|
|
|
|
bTrans ? HPPL_OP_T : HPPL_OP_N,
|
|
|
|
|
outData,
|
|
|
|
|
out.getHeight(),
|
|
|
|
|
out.getWidth(),
|
|
|
|
@ -153,7 +118,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
|
|
|
|
|
hl_matrix_dense_mul_csr(aData,
|
|
|
|
|
HPPL_OP_N,
|
|
|
|
|
bData,
|
|
|
|
|
bTrans,
|
|
|
|
|
bTrans ? HPPL_OP_T : HPPL_OP_N,
|
|
|
|
|
outData,
|
|
|
|
|
out.getHeight(),
|
|
|
|
|
out.getWidth(),
|
|
|
|
@ -168,35 +133,26 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
|
|
|
|
|
const GpuMatrix& a,
|
|
|
|
|
const GpuMatrix& b,
|
|
|
|
|
real scaleAB,
|
|
|
|
|
real scaleT) {
|
|
|
|
|
real scaleT,
|
|
|
|
|
bool aTrans,
|
|
|
|
|
bool bTrans,
|
|
|
|
|
bool cTrans) {
|
|
|
|
|
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
|
|
|
|
|
CHECK(!out.isTransposed()) << "Transpose is not supported for out matrix";
|
|
|
|
|
|
|
|
|
|
if (!a.isTransposed() && !b.isTransposed()) {
|
|
|
|
|
CHECK(out.getHeight() == a.getHeight() &&
|
|
|
|
|
out.getWidth() == b.getWidth() &&
|
|
|
|
|
a.getWidth() == b.getHeight());
|
|
|
|
|
} else if (a.isTransposed() && !b.isTransposed()) {
|
|
|
|
|
CHECK(out.getHeight() == a.getWidth() &&
|
|
|
|
|
out.getWidth() == b.getWidth() &&
|
|
|
|
|
a.getHeight() == b.getHeight());
|
|
|
|
|
} else if (!a.isTransposed() && b.isTransposed()) {
|
|
|
|
|
CHECK(out.getHeight() == a.getHeight() &&
|
|
|
|
|
out.getWidth() == b.getHeight() &&
|
|
|
|
|
a.getWidth() == b.getWidth());
|
|
|
|
|
} else {
|
|
|
|
|
LOG(FATAL) << "Not support for both a and b are Transposed Matrices";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
|
|
|
|
|
hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
|
|
|
|
|
int dimK = !b.isTransposed() ? b.getHeight() : b.getWidth();
|
|
|
|
|
real* aData = const_cast<real*>(a.getData());
|
|
|
|
|
real* bData = const_cast<real*>(b.getData());
|
|
|
|
|
hl_sparse_matrix_s outData = out.sMatrix_.get();
|
|
|
|
|
|
|
|
|
|
hl_sparse_matrix_mul(aData, aTrans, bData, bTrans, outData,
|
|
|
|
|
out.getHeight(), out.getWidth(), dimK, scaleAB, scaleT);
|
|
|
|
|
hl_sparse_matrix_mul(aData,
|
|
|
|
|
aTrans ? HPPL_OP_T : HPPL_OP_N,
|
|
|
|
|
bData,
|
|
|
|
|
bTrans ? HPPL_OP_T : HPPL_OP_N,
|
|
|
|
|
outData,
|
|
|
|
|
out.getHeight(),
|
|
|
|
|
out.getWidth(),
|
|
|
|
|
!bTrans ? b.getHeight() : b.getWidth(),
|
|
|
|
|
scaleAB,
|
|
|
|
|
scaleT);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|