From 96642a76fdfaceb5d89643b122f36b7e0e5e46fd Mon Sep 17 00:00:00 2001 From: mamba_ni Date: Thu, 6 Aug 2020 10:48:21 +0800 Subject: [PATCH] support cusolver AND OPS cholesky_solve fix bug clang-format format fix --- mindspore/ccsrc/CMakeLists.txt | 3 +- .../gpu/math/cholesky_solve_gpu_kernel.cc | 23 ++ .../gpu/math/cholesky_solve_gpu_kernel.h | 254 ++++++++++++++++++ .../ccsrc/runtime/device/gpu/gpu_common.h | 16 ++ .../runtime/device/gpu/gpu_device_manager.cc | 9 +- .../runtime/device/gpu/gpu_device_manager.h | 4 + mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/_thor_ops.py | 31 +++ 8 files changed, 339 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 1339b4b9ea..56bedb7d67 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -271,7 +271,8 @@ if (ENABLE_GPU) ${CUDA_PATH}/lib64/libcurand.so ${CUDNN_PATH}/lib64/libcudnn.so ${CUDA_PATH}/lib64/libcudart.so - ${CUDA_PATH}/lib64/stubs/libcuda.so) + ${CUDA_PATH}/lib64/stubs/libcuda.so + ${CUDA_PATH}/lib64/libcusolver.so) if (ENABLE_MPI) set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH}) endif() diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc new file mode 100644 index 0000000000..9ef1429568 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CholeskyGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h new file mode 100644 index 0000000000..abbbe049d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h @@ -0,0 +1,254 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H +#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class CholeskyGpuKernel : public GpuKernel { + public: + CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {} + ~CholeskyGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + if (!use_split_matrix) { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = input1_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + } + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + float alpha = 1; + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, + d_array_addr, lda_, d_identity_addr, ldb_, batch_), + "cublas trsm batched Fail"); + } else { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + auto d_batch_input_addr = GetDeviceAddress(workspace, 3); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = d_batch_input_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + } + Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast(stream_ptr)); + MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + float alpha = 1; + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, + d_array_addr, lda_, d_identity_addr, ldb_, batch_), + "cublas trsm batched Fail"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); + blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + split_dim = GetAttr(kernel_node, "split_dim"); + if (split_dim == 0) { + use_split_matrix = false; + if (in_shape.size() == 2) { + batch_ = 1; + if (in_shape[0] != in_shape[1]) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + } + } else if (in_shape.size() == 3) { + batch_ = SizeToInt(in_shape[0]); + if (in_shape[1] != in_shape[2]) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + } + } else { + MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; + } + + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } else { + if (in_shape.size() != 2) { + MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2."; + } + height = in_shape[0]; + width = in_shape[1]; + if (height != width) { + MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input."; + } + if (SizeToInt(height) <= split_dim) { + use_split_matrix = false; + batch_ = 1; + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } else { + use_split_matrix = true; + int batch = SizeToInt(in_shape[1]) / split_dim; + res_dim = in_shape[1] - batch * split_dim; + if (res_dim == 0) { + batch_ = batch; + } else { + batch_ = batch + 1; + } + m_ = split_dim; + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } + } + return true; + } + + protected: + void InitSizeLists() override { + if (!use_split_matrix) { + size_t unit_size = sizeof(T); + size_t input_size = batch_ * m_ * lda_ * unit_size; + input_size_list_.push_back(input_size); + size_t output_size = batch_ * m_ * lda_ * unit_size; + output_size_list_.push_back(output_size); + size_t workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(int); + workspace_size_list_.push_back(workspace_size); + } else { + size_t unit_size = sizeof(T); + size_t input_size = height * width * unit_size; + input_size_list_.push_back(input_size); + size_t output_size = batch_ * m_ * lda_ * unit_size; + output_size_list_.push_back(output_size); + size_t workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(int); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * m_ * lda_ * unit_size; + workspace_size_list_.push_back(workspace_size); + } + } + + private: + size_t batch_; + size_t m_; + size_t lda_; + size_t ldb_; + int res_dim; + int split_dim; + bool is_null_input_; + bool use_split_matrix; + size_t height; + size_t width; + cusolverDnHandle_t handle_; + cublasHandle_t blas_handle_; + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + std::vector h_array; + std::vector h_identity; + std::vector h_identity_data; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h index 6ed682e704..ea2b321714 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h @@ -93,6 +93,22 @@ namespace gpu { } \ } +#define CHECK_CUSOLVER_RET_WITH_EXCEPT(expression, message) \ + { \ + cusolverStatus_t status = (expression); \ + if (status != CUSOLVER_STATUS_SUCCESS) { \ + MS_LOG(EXCEPTION) << "cusolver Error: " << message << " | Error Number: " << status; \ + } \ + } + +#define CHECK_CUSOLVER_RET_WITH_ERROR(expression, message) \ + { \ + cusolverStatus_t status = (expression); \ + if (status != CUSOLVER_STATUS_SUCCESS) { \ + MS_LOG(ERROR) << "cusolver Error: " << message << " | Error Number: " << status; \ + } \ + } + #define CHECK_NCCL_RET_WITH_EXCEPT(expression, message) \ { \ int result = (expression); \ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc index 8f17fc20b5..5207bdf1b6 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc @@ -32,6 +32,10 @@ void GPUDeviceManager::InitDevice() { CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), "Failed to set stream for cuBLAS handle."); + CHECK_CUSOLVER_RET_WITH_EXCEPT(cusolverDnCreate(&cusolver_dn_handle_), "Failed to create cusolver dn handle."); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cusolver dn handle"); CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") } @@ -47,6 +51,9 @@ void GPUDeviceManager::ReleaseDevice() { if (cublas_handle_ != nullptr) { CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); } + if (cusolver_dn_handle_ != nullptr) { + CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle."); + } CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); } @@ -79,7 +86,7 @@ bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } - +const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; } bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h index 69f33d41c4..b2bb618621 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include "runtime/device/gpu/cuda_driver.h" @@ -43,6 +44,7 @@ class GPUDeviceManager { const cudnnHandle_t &GetCudnnHandle() const; const cublasHandle_t &GetCublasHandle() const; + const cusolverDnHandle_t &GetCusolverDnHandle() const; bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; @@ -73,6 +75,8 @@ class GPUDeviceManager { // handle used for cuBLAS kernels. cublasHandle_t cublas_handle_{nullptr}; + // handle used for cusolver dn kernels; + cusolverDnHandle_t cusolver_dn_handle_{nullptr}; bool dev_id_init_; uint32_t cur_dev_id_; }; diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 1c76b737c7..f3c3ee04e2 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -86,7 +86,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, CusMatMulCubeDenseRight, - CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient) + CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky) from .sparse_ops import SparseToDense __all__ = [ diff --git a/mindspore/ops/operations/_thor_ops.py b/mindspore/ops/operations/_thor_ops.py index 178e18414e..a4f2335c9b 100644 --- a/mindspore/ops/operations/_thor_ops.py +++ b/mindspore/ops/operations/_thor_ops.py @@ -607,3 +607,34 @@ class UpdateThorGradient(PrimitiveWithInfer): validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, [mstype.float32], self.name) return x2_dtype + +class Cholesky(PrimitiveWithInfer): + """ + Inner API for resnet50 THOR GPU backend + """ + @prim_attr_register + def __init__(self, split_dim=0): + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.split_dim = split_dim + self.add_prim_attr('split_dim', self.split_dim) + + def infer_shape(self, x1_shape): + if self.split_dim != 0: + assert len(x1_shape) == 2 + height = x1_shape[0] + width = x1_shape[1] + assert height == width + if height <= self.split_dim: + out_shape = [1, height, width] + else: + batch = height // self.split_dim + if height != batch * self.split_dim: + batch += 1 + out_shape = [batch, self.split_dim, self.split_dim] + else: + out_shape = x1_shape + return out_shape + + def infer_dtype(self, x1_dtype): + validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) + return x1_dtype