update matmul implementation for GPU

pull/11862/head
Corleone 4 years ago
parent 8a61767f32
commit 82de371987

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -58,14 +58,24 @@ class MatMulGpuKernel : public GpuKernel {
auto stride_c = SizeToInt(m_ * n_); auto stride_c = SizeToInt(m_ * n_);
try { try {
CHECK_CUBLAS_RET_WITH_EXCEPT( // Use cublasGemmEx to get high performance when batch_ is 1
kernel_node_, if (batch_ == 1) {
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, input1_addr,
"cublasSgemm Call Fail"); dtype_a_, lda, &beta, output_addr, dtype_c_, ldc, CUDA_R_32F, algo_),
"cublasGemmEx Call Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
"cublasGemmStridedBatchedEx Call Fail");
}
} catch (const std::exception &e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoking cublas "
<< (batch_ == 1 ? "cublasGemmEx" : "cublasGemmStridedBatchedEx");
} }
return true; return true;
} }

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -150,7 +150,7 @@ def test_tensor_dot_fp16():
network = NetTensorDot(axes) network = NetTensorDot(axes)
ms_result_np = network(x1_tensor, x2_tensor).asnumpy() ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
np_result = np.tensordot(x1, x2, axes) np_result = np.tensordot(x1, x2, axes)
np.testing.assert_array_almost_equal(ms_result_np, np_result) assert np.allclose(ms_result_np, np_result, rtol=1e-3, atol=1e-3)
# 3D # 3D
shape_x1 = (60, 30, 450) shape_x1 = (60, 30, 450)
@ -164,7 +164,7 @@ def test_tensor_dot_fp16():
network = NetTensorDot(axes) network = NetTensorDot(axes)
ms_result_np = network(x1_tensor, x2_tensor).asnumpy() ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
np_result = np.tensordot(x1, x2, axes) np_result = np.tensordot(x1, x2, axes)
np.testing.assert_array_almost_equal(ms_result_np, np_result) assert np.allclose(ms_result_np, np_result, rtol=1e-3, atol=6e0)
@pytest.mark.level0 @pytest.mark.level0
@ -173,7 +173,7 @@ def test_tensor_dot_fp16():
def test_tensor_dot_outer(): def test_tensor_dot_outer():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
np.random.seed(2746) np.random.seed(2746)
shape_x1 = (1, 2, 3) # incompatable dims for x1 and x2 shape_x1 = (1, 2, 3) # incompatible dims for x1 and x2
shape_x2 = (4, 5, 6) shape_x2 = (4, 5, 6)
axes = 0 # outer product does not require multiplicable dims axes = 0 # outer product does not require multiplicable dims
x1 = np.random.random(shape_x1).astype(np.float32) x1 = np.random.random(shape_x1).astype(np.float32)

Loading…
Cancel
Save