@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H_
 
 #include <cublas_v2.h>
 #include <cuda_runtime_api.h>
@@ -47,8 +47,10 @@ class MatMulGpuKernel : public GpuKernel {
     auto input2_addr = GetDeviceAddress<T>(inputs, 1);
     auto output_addr = GetDeviceAddress<T>(outputs, 0);
 
-    const float alpha = 1;
-    const float beta = 0;
+    T alpha = static_cast<T>(1.0f);
+    T beta = static_cast<T>(0.0f);
+    cudaDataType_t compute_type = (dtype_a_ == CUDA_R_64F) ? CUDA_R_64F : CUDA_R_32F;
+
     const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_);
     const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_);
     const int ldc = n_;
@@ -58,20 +60,44 @@ class MatMulGpuKernel : public GpuKernel {
     auto stride_c = SizeToInt(m_ * n_);
 
     try {
-      // Use cublasGemmEx to get high performance when batch_ is 1
-      if (batch_ == 1) {
-        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
-                                     cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
-                                                  SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, input1_addr,
-                                                  dtype_a_, lda, &beta, output_addr, dtype_c_, ldc, CUDA_R_32F, algo_),
-                                     "cublasSgemm Call Fail");
+      if (dtype_a_ == CUDA_R_16F) {
+        const float alphaf = 1.0f;
+        const float betaf = 0.0f;
+        // Use cublasGemmEx to get high performance when batch_ is 1
+        if (batch_ == 1) {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), &alphaf,
+                         input2_addr, dtype_b_, ldb, input1_addr, dtype_a_, lda, &betaf, output_addr, dtype_c_, ldc,
+                         compute_type, algo_),
+            "cublasGemmEx failed");
+        } else {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
+                                       SizeToInt(k_), &alphaf, input2_addr, dtype_b_, ldb, stride_b, input1_addr,
+                                       dtype_a_, lda, stride_a, &betaf, output_addr, dtype_c_, ldc, stride_c, batch_,
+                                       compute_type, algo_),
+            "cublasGemmStridedBatchedEx failed");
+        }
       } else {
-        CHECK_CUBLAS_RET_WITH_EXCEPT(
-          kernel_node_,
-          cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
-                                     &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
-                                     &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
-          "cublasGemmStridedBatchedEx Call Fail");
+        // Use cublasGemmEx to get high performance when batch_ is 1
+        if (batch_ == 1) {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), &alpha,
+                         input2_addr, dtype_b_, ldb, input1_addr, dtype_a_, lda, &beta, output_addr, dtype_c_, ldc,
+                         compute_type, algo_),
+            "cublasGemmEx failed");
+        } else {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
+                                       SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr,
+                                       dtype_a_, lda, stride_a, &beta, output_addr, dtype_c_, ldc, stride_c, batch_,
+                                       compute_type, algo_),
+            "cublasGemmStridedBatchedEx failed");
+        }
       }
     } catch (const std::exception &e) {
       MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas "
@@ -85,6 +111,10 @@ class MatMulGpuKernel : public GpuKernel {
     dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
     dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (dtype_a_ != dtype_b_ || dtype_a_ != dtype_c_) {
+      MS_LOG(EXCEPTION) << "input and output types are not the same in " << node_name;
+    }
     if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
       MS_LOG(INFO) << "input and output type is float16, allow to use Tensor Core operations if possible";
       algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
@@ -174,4 +204,4 @@ class MatMulGpuKernel : public GpuKernel {
 } // namespace kernel
 } // namespace mindspore
 
-#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H_
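
Note on the call pattern above: cuBLAS assumes column-major storage while the kernel's tensors are row-major, so the row-major product C = A * B is computed as the column-major product C^T = B^T * A^T. That is why input2_addr is passed in the first operand slot, input1_addr in the second, and n_ and m_ are swapped in every cublasGemmEx / cublasGemmStridedBatchedEx call. The sketch below is not part of this patch; it is a minimal standalone illustration of the same pattern, it assumes the cudaDataType_t compute-type parameter the kernel passes (the pre-CUDA-11 form of cublasGemmEx; on CUDA 11 and later that parameter is a cublasComputeType_t such as CUBLAS_COMPUTE_32F), and error checking is omitted for brevity.

// Minimal sketch (not MindSpore code): a 2x3 by 3x4 row-major multiply through cublasGemmEx,
// passing B in the first operand slot and swapping m/n exactly as the kernel above does.
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cstdio>
#include <vector>

int main() {
  const int m = 2, k = 3, n = 4;              // row-major C(m x n) = A(m x k) * B(k x n)
  std::vector<float> a = {1, 2, 3, 4, 5, 6};  // A in row-major order
  std::vector<float> b(k * n, 1.0f);          // B filled with ones
  std::vector<float> c(m * n, 0.0f);

  float *d_a = nullptr, *d_b = nullptr, *d_c = nullptr;
  cudaMalloc(&d_a, a.size() * sizeof(float));
  cudaMalloc(&d_b, b.size() * sizeof(float));
  cudaMalloc(&d_c, c.size() * sizeof(float));
  cudaMemcpy(d_a, a.data(), a.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, b.data(), b.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle = nullptr;
  cublasCreate(&handle);
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // A row-major matrix is byte-for-byte its own transpose in column-major terms, so the
  // column-major view computes C^T (n x m) = B^T (n x k) * A^T (k x m): B first, A second.
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
               d_b, CUDA_R_32F, n,         // leading dimension of untransposed B is n (cf. ldb above)
               d_a, CUDA_R_32F, k,         // leading dimension of untransposed A is k (cf. lda above)
               &beta, d_c, CUDA_R_32F, n,  // ldc is n, as in the kernel
               CUDA_R_32F, CUBLAS_GEMM_DEFAULT);  // compute type passed as cudaDataType_t (pre-CUDA-11 style)

  cudaMemcpy(c.data(), d_c, c.size() * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      printf("%g ", c[i * n + j]);  // since B is all ones, each entry is the row sum of A
    }
    printf("\n");
  }
  cublasDestroy(handle);
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
  return 0;
}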
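One more detail behind the branching: cublasGemmEx and cublasGemmStridedBatchedEx take their alpha/beta scaling factors in the compute type rather than the data type. With half inputs the kernel accumulates in CUDA_R_32F, so the scalars must be float (the alphaf/betaf branch); with float or double inputs the T-typed scalars already match the compute type chosen above. A hypothetical helper restating that rule, not part of the patch:

#include <cublas_v2.h>

// Hypothetical restatement (not in the patch) of the kernel's compute-type rule:
// half and float inputs accumulate in float, double inputs accumulate in double,
// and the alpha/beta scalars handed to cublasGemmEx must match that compute type.
inline cudaDataType_t SelectComputeType(cudaDataType_t input_type) {
  return (input_type == CUDA_R_64F) ? CUDA_R_64F : CUDA_R_32F;
}

With that rule, only the CUDA_R_16F case needs scalars of a different type than T, which is exactly the split the try block makes.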