|
|
|
@ -163,17 +163,20 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
if (trans_y) {
|
|
|
|
|
const int M = Y->numel() / N;
|
|
|
|
|
VLOG(3) << "MatMul's case 2";
|
|
|
|
|
blas.GEMV(false, M, N, 1., y_data, x_data, 0., Out->data<T>());
|
|
|
|
|
blas.GEMV(false, M, N, static_cast<T>(1), y_data, x_data,
|
|
|
|
|
static_cast<T>(0), Out->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
const int M = y_dims[y_ndim - 1];
|
|
|
|
|
const int batch_size = Y->numel() / (M * N);
|
|
|
|
|
if (batch_size == 1) {
|
|
|
|
|
VLOG(3) << "MatMul's case 3";
|
|
|
|
|
blas.GEMV(true, N, M, 1., y_data, x_data, 0., Out->data<T>());
|
|
|
|
|
blas.GEMV(true, N, M, static_cast<T>(1), y_data, x_data,
|
|
|
|
|
static_cast<T>(0), Out->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "MatMul's case 4";
|
|
|
|
|
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, y_data,
|
|
|
|
|
x_data, 0, Out->data<T>(), batch_size, M * N, 0);
|
|
|
|
|
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
|
|
|
|
|
y_data, x_data, static_cast<T>(0), Out->data<T>(),
|
|
|
|
|
batch_size, M * N, 0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
@ -205,16 +208,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
const int batch_size = X->numel() / (M * N);
|
|
|
|
|
if (batch_size == 1) {
|
|
|
|
|
VLOG(3) << "MatMul's case 5";
|
|
|
|
|
blas.GEMV(true, N, M, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
|
|
|
|
|
blas.GEMV(true, N, M, static_cast<T>(1), x_data, y_data,
|
|
|
|
|
static_cast<T>(0), Out->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "MatMul's case 6";
|
|
|
|
|
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, x_data,
|
|
|
|
|
y_data, 0, Out->data<T>(), batch_size, M * N, 0);
|
|
|
|
|
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
|
|
|
|
|
x_data, y_data, static_cast<T>(0), Out->data<T>(),
|
|
|
|
|
batch_size, M * N, 0);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
const int M = X->numel() / N;
|
|
|
|
|
VLOG(3) << "MatMul's case 7";
|
|
|
|
|
blas.GEMV(false, M, N, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
|
|
|
|
|
blas.GEMV(false, M, N, static_cast<T>(1), x_data, y_data,
|
|
|
|
|
static_cast<T>(0), Out->data<T>());
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -263,37 +269,38 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
if (x_batch_size == 1 && y_batch_size == 1) {
|
|
|
|
|
VLOG(3) << "MatMul's case 8";
|
|
|
|
|
blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
|
|
|
|
|
y_data, 0.0f, Out->data<T>());
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast<T>(1),
|
|
|
|
|
x_data, y_data, static_cast<T>(0), Out->data<T>());
|
|
|
|
|
} else if (x_batch_size == 1) {
|
|
|
|
|
if (M == 1 && trans_y) {
|
|
|
|
|
VLOG(3) << "MatMul's case 9";
|
|
|
|
|
blas.GEMV(false, y_batch_size * N, K, 1.0f, y_data, x_data, 0.0f,
|
|
|
|
|
Out->data<T>());
|
|
|
|
|
blas.GEMV(false, y_batch_size * N, K, static_cast<T>(1), y_data, x_data,
|
|
|
|
|
static_cast<T>(0), Out->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "MatMul's case 10";
|
|
|
|
|
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
|
|
|
|
|
x_data, y_data, 0, Out->data<T>(), out_batch_size, 0,
|
|
|
|
|
K * N);
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
|
|
|
|
|
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
|
|
|
|
|
Out->data<T>(), out_batch_size, 0, K * N);
|
|
|
|
|
}
|
|
|
|
|
} else if (y_batch_size == 1) {
|
|
|
|
|
if (!trans_x) {
|
|
|
|
|
VLOG(3) << "MatMul's case 11";
|
|
|
|
|
blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans,
|
|
|
|
|
x_batch_size * M, N, K, 1.0f, x_data, y_data, 0.0f,
|
|
|
|
|
Out->data<T>());
|
|
|
|
|
x_batch_size * M, N, K, static_cast<T>(1), x_data, y_data,
|
|
|
|
|
static_cast<T>(0), Out->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "MatMul's case 12";
|
|
|
|
|
blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K,
|
|
|
|
|
1.0f, x_data, y_data, 0, Out->data<T>(), out_batch_size,
|
|
|
|
|
M * K, 0);
|
|
|
|
|
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
|
|
|
|
|
Out->data<T>(), out_batch_size, M * K, 0);
|
|
|
|
|
}
|
|
|
|
|
} else if (!is_broadcast_dims) {
|
|
|
|
|
VLOG(3) << "MatMul's case 13";
|
|
|
|
|
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
|
|
|
|
|
y_data, 0, Out->data<T>(), out_batch_size, M * K, K * N);
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
|
|
|
|
|
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
|
|
|
|
|
Out->data<T>(), out_batch_size, M * K, K * N);
|
|
|
|
|
} else {
|
|
|
|
|
// In this case (is_broadcast_dims is true), strided GEMM can't be used
|
|
|
|
|
std::vector<const T*> x_ptr(out_batch_size);
|
|
|
|
@ -314,9 +321,9 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "MatMul's case 14";
|
|
|
|
|
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
|
|
|
|
|
x_ptr.data(), y_ptr.data(), 0.0f, out_ptr.data(),
|
|
|
|
|
out_batch_size);
|
|
|
|
|
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
|
|
|
|
|
static_cast<T>(1), x_ptr.data(), y_ptr.data(),
|
|
|
|
|
static_cast<T>(0), out_ptr.data(), out_batch_size);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|