|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -58,14 +58,24 @@ class MatMulGpuKernel : public GpuKernel {
|
|
|
|
|
auto stride_c = SizeToInt(m_ * n_);
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
CHECK_CUBLAS_RET_WITH_EXCEPT(
|
|
|
|
|
kernel_node_,
|
|
|
|
|
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
|
|
|
|
|
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
|
|
|
|
|
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
|
|
|
|
|
"cublasSgemm Call Fail");
|
|
|
|
|
// Use cublasGemmEx to get high performance when batch_ is 1
|
|
|
|
|
if (batch_ == 1) {
|
|
|
|
|
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
|
|
|
|
|
cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
|
|
|
|
|
SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, input1_addr,
|
|
|
|
|
dtype_a_, lda, &beta, output_addr, dtype_c_, ldc, CUDA_R_32F, algo_),
|
|
|
|
|
"cublasSgemm Call Fail");
|
|
|
|
|
} else {
|
|
|
|
|
CHECK_CUBLAS_RET_WITH_EXCEPT(
|
|
|
|
|
kernel_node_,
|
|
|
|
|
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
|
|
|
|
|
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
|
|
|
|
|
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
|
|
|
|
|
"cublasGemmStridedBatchedEx Call Fail");
|
|
|
|
|
}
|
|
|
|
|
} catch (const std::exception &e) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas "
|
|
|
|
|
<< (batch_ == 1 ? "cublasGemmEx" : "cublasGemmStridedBatchedEx");
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|