!4038 support ops cholesky for resnet50 thor gpu

Merge pull request !4038 from mamba_ni/master
pull/4038/MERGE
Committed by mindspore-ci-bot via Gitee, 5 years ago
commit ba3a2976dc

@@ -271,7 +271,8 @@ if (ENABLE_GPU)
${CUDA_PATH}/lib64/libcurand.so
${CUDNN_PATH}/lib64/libcudnn.so
${CUDA_PATH}/lib64/libcudart.so
${CUDA_PATH}/lib64/stubs/libcuda.so)
${CUDA_PATH}/lib64/stubs/libcuda.so
${CUDA_PATH}/lib64/libcusolver.so)
if (ENABLE_MPI)
set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
endif()

@@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyGpuKernel, float)
} // namespace kernel
} // namespace mindspore
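
For reference, with this registration plus the Python primitive added later in this patch, the op can be exercised end to end. The snippet below is a minimal sketch, assuming a GPU build of MindSpore with this change applied; the Net wrapper and the symmetric positive-definite input are illustrative and not part of the patch.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations._thor_ops import Cholesky

class Net(nn.Cell):
    """Thin Cell wrapper so the primitive can run in graph mode."""
    def __init__(self, split_dim=0):
        super(Net, self).__init__()
        self.cholesky = Cholesky(split_dim=split_dim)

    def construct(self, x):
        return self.cholesky(x)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# A @ A^T plus a diagonal boost is symmetric positive-definite, so the
# factorization is well defined.
a = np.random.rand(4, 4).astype(np.float32)
spd = a @ a.T + 4.0 * np.eye(4, dtype=np.float32)

out = Net(split_dim=0)(Tensor(spd))
print(out.asnumpy().shape)  # (4, 4): split_dim == 0 keeps the input shape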

@@ -93,6 +93,22 @@ namespace gpu {
} \
}
#define CHECK_CUSOLVER_RET_WITH_EXCEPT(expression, message) \
{ \
cusolverStatus_t status = (expression); \
if (status != CUSOLVER_STATUS_SUCCESS) { \
MS_LOG(EXCEPTION) << "cusolver Error: " << message << " | Error Number: " << status; \
} \
}
#define CHECK_CUSOLVER_RET_WITH_ERROR(expression, message) \
{ \
cusolverStatus_t status = (expression); \
if (status != CUSOLVER_STATUS_SUCCESS) { \
MS_LOG(ERROR) << "cusolver Error: " << message << " | Error Number: " << status; \
} \
}
#define CHECK_NCCL_RET_WITH_EXCEPT(expression, message) \
{ \
int result = (expression); \

@@ -32,6 +32,10 @@ void GPUDeviceManager::InitDevice() {
CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle.");
CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
"Failed to set stream for cuBLAS handle.");
CHECK_CUSOLVER_RET_WITH_EXCEPT(cusolverDnCreate(&cusolver_dn_handle_), "Failed to create cusolver dn handle.");
CHECK_CUSOLVER_RET_WITH_EXCEPT(
cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
"Failed to set stream for cusolver dn handle");
CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator")
}
@@ -47,6 +51,9 @@ void GPUDeviceManager::ReleaseDevice() {
if (cublas_handle_ != nullptr) {
CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle.");
}
if (cusolver_dn_handle_ != nullptr) {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle.");
}
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
}
@@ -79,7 +86,7 @@ bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {

@@ -19,6 +19,7 @@
#include <cudnn.h>
#include <cublas_v2.h>
#include <cusolverDn.h>
#include <vector>
#include <memory>
#include "runtime/device/gpu/cuda_driver.h"
@@ -43,6 +44,7 @@ class GPUDeviceManager {
const cudnnHandle_t &GetCudnnHandle() const;
const cublasHandle_t &GetCublasHandle() const;
const cusolverDnHandle_t &GetCusolverDnHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
@@ -73,6 +75,8 @@ class GPUDeviceManager {
// handle used for cuBLAS kernels.
cublasHandle_t cublas_handle_{nullptr};
// handle used for cusolver dn kernels.
cusolverDnHandle_t cusolver_dn_handle_{nullptr};
bool dev_id_init_;
uint32_t cur_dev_id_;
};

@@ -87,7 +87,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient)
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky)
from .sparse_ops import SparseToDense
__all__ = [

@@ -607,3 +607,34 @@ class UpdateThorGradient(PrimitiveWithInfer):
validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
return x2_dtype


class Cholesky(PrimitiveWithInfer):
    """
    Inner API for the ResNet-50 THOR GPU backend.

    Computes the Cholesky decomposition of a symmetric positive-definite
    float32 matrix. When `split_dim` is non-zero, the matrix is handled as
    a batch of `split_dim x split_dim` blocks, and the output shape is
    rounded up to a whole number of blocks.
    """

    @prim_attr_register
    def __init__(self, split_dim=0):
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.split_dim = split_dim
        self.add_prim_attr('split_dim', self.split_dim)

    def infer_shape(self, x1_shape):
        if self.split_dim != 0:
            # Splitting is only defined for square 2-D inputs.
            assert len(x1_shape) == 2
            height = x1_shape[0]
            width = x1_shape[1]
            assert height == width
            if height <= self.split_dim:
                # The whole matrix fits in a single block.
                out_shape = [1, height, width]
            else:
                # Ceiling division: a partial trailing block still
                # occupies a full split_dim x split_dim slot.
                batch = height // self.split_dim
                if height != batch * self.split_dim:
                    batch += 1
                out_shape = [batch, self.split_dim, self.split_dim]
        else:
            out_shape = x1_shape
        return out_shape

    def infer_dtype(self, x1_dtype):
        validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
        return x1_dtype
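
The shape arithmetic in infer_shape is easiest to check with concrete numbers. Below is a standalone sketch (not part of the patch) that mirrors the method's logic and runs without MindSpore installed:

def cholesky_out_shape(x1_shape, split_dim=0):
    """Mirror of Cholesky.infer_shape, for illustration only."""
    if split_dim == 0:
        return x1_shape
    height, width = x1_shape
    assert height == width
    if height <= split_dim:
        return [1, height, width]
    # Ceiling division: a partial trailing block still gets a full slot.
    batch = height // split_dim
    if height != batch * split_dim:
        batch += 1
    return [batch, split_dim, split_dim]

print(cholesky_out_shape([2048, 2048], split_dim=128))  # [16, 128, 128]
print(cholesky_out_shape([1000, 1000], split_dim=128))  # [8, 128, 128] (7 full blocks + 1 partial)
print(cholesky_out_shape([64, 64], split_dim=128))      # [1, 64, 64]
print(cholesky_out_shape([512, 512]))                   # [512, 512] (split_dim == 0)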
