From d7bb5cd058a9ec5c6625867fbc9ff86abfc48c73 Mon Sep 17 00:00:00 2001 From: Jonathan Yan Date: Thu, 3 Dec 2020 10:09:07 -0500 Subject: [PATCH] Fix CI Alarms --- .../gpu/arrays/gatherv2_gpu_kernel.cc | 7 + .../arrays/unsorted_segment_sum_gpu_kernel.cc | 1 - .../gpu/data/dataset_iterator_kernel.cc | 16 +- .../gpu/data/dataset_iterator_kernel.h | 1 + .../kernel_compiler/gpu/gpu_kernel_factory.cc | 26 +- .../kernel_compiler/gpu/gpu_kernel_factory.h | 1 + .../gpu/math/cholesky_trsm_solve_gpu_kernel.h | 278 +++++++++--------- .../gpu/math/update_thor_gradient.h | 4 +- .../gpu/nccl/nccl_collective_gpu_kernel.h | 90 ++++-- .../gpu/nn/ctcloss_gpu_kernel.h | 159 ++++++---- .../gpu/nn/l2normalize_gpu_kernel.h | 23 +- .../gpu/nn/l2normalize_grad_gpu_kernel.h | 58 ++-- .../kernel_compiler/gpu/nn/pad_gpu_kernel.h | 23 +- .../gpu/nn/pooling_grad_gpu_kernel.h | 38 ++- 14 files changed, 430 insertions(+), 295 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc index aedf4698cb..1b6c585271 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc @@ -22,10 +22,12 @@ MS_REG_GPU_KERNEL_TWO( GatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), GatherV2GpuFwdKernel, float, int) + MS_REG_GPU_KERNEL_TWO( GatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int) + MS_REG_GPU_KERNEL_TWO(GatherV2, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -33,6 +35,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), GatherV2GpuFwdKernel, float, int) + MS_REG_GPU_KERNEL_TWO(GatherV2, KernelAttr() .AddInputAttr(kNumberTypeFloat16) @@ -40,14 +43,17 @@ MS_REG_GPU_KERNEL_TWO(GatherV2, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int) + MS_REG_GPU_KERNEL_TWO( SparseGatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), GatherV2GpuFwdKernel, float, int) + MS_REG_GPU_KERNEL_TWO( SparseGatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int) + MS_REG_GPU_KERNEL_TWO(SparseGatherV2, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -55,6 +61,7 @@ MS_REG_GPU_KERNEL_TWO(SparseGatherV2, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), GatherV2GpuFwdKernel, float, int) + MS_REG_GPU_KERNEL_TWO(SparseGatherV2, KernelAttr() .AddInputAttr(kNumberTypeFloat16) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc index 56c029b2d6..03f5096c34 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc @@ -18,7 +18,6 @@ namespace mindspore { namespace kernel { - MS_REG_GPU_KERNEL_TWO( UnsortedSegmentSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
index 4b9749a86a..59269b7703 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
@@ -85,10 +85,7 @@ bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
 void DatasetIteratorKernel::InitSizeLists() { return; }
 
-bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
-                                   const std::vector<AddressPtr> &outputs, void *stream) {
-  void *addr = nullptr;
-  size_t len = 0;
+bool DatasetIteratorKernel::ReadDevice(void **addr, size_t *len) {
   uint64_t start_time_stamp = 0;
   uint32_t queue_size = 0;
@@ -98,7 +95,7 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
     start_time_stamp = profiling_op_->GetTimeStamp();
     queue_size = GpuBufferMgr::GetInstance().Size(handle_);
   }
-  auto ret = GpuBufferMgr::GetInstance().Front(handle_, &addr, &len);
+  auto ret = GpuBufferMgr::GetInstance().Front(handle_, addr, len);
   if (ret == device::SUCCESS) {
     if (profiling_enable_) {
       uint64_t end_time_stamp = profiling_op_->GetTimeStamp();
@@ -129,7 +126,16 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
     MS_LOG(ERROR) << "Get data failed, errcode " << ret;
     return false;
   }
+  return true;
+}
 
+bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                   const std::vector<AddressPtr> &outputs, void *stream) {
+  void *addr = nullptr;
+  size_t len = 0;
+  if (!ReadDevice(&addr, &len)) {
+    return false;
+  }
   if (total_bytes_ != len) {
     MS_LOG(ERROR) << "Dataset front error. read: " << len << ", expect: " << total_bytes_ << ", ";
     return false;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h
index 2aa62880f7..331c4a695d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h
@@ -43,6 +43,7 @@ class DatasetIteratorKernel : public GpuKernel {
   void InitSizeLists() override;
 
  private:
+  bool ReadDevice(void **addr, size_t *len);
  std::string queue_name_;
  unsigned int handle_;
  size_t total_bytes_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
index 7e941114dc..93ac6f9f90 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
@@ -107,10 +107,23 @@ bool GpuKernelFactory::ReducePrecision(
   return GpuKernelFactory::SearchRegistered(kernel_name, builder->Build());
 }
 
+void GpuKernelFactory::CheckSM(const KernelBuildInfo *kernel_info, const size_t &input_index) {
+  const int major_sm = GET_MAJOR_SM;
+  const bool check_sm = mindspore::device::gpu::CudaCommon::GetInstance().check_sm();
+  if (check_sm && major_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) {
+    if (major_sm < MINIUM_SM) {
+      MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM
+                        << ", but the current device's computing capacity is " << major_sm;
+    }
+    MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM
+                    << ", but the current device's computing capacity is " << major_sm;
+    mindspore::device::gpu::CudaCommon::GetInstance().set_check_sm(false);
+  }
+}
+
 std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &kernel_name,
                                                              const KernelBuildInfo *kernel_info) {
   auto iter = map_kernel_name_to_creater_.find(kernel_name);
-  const int marjor_sm = GET_MAJOR_SM;
   if (map_kernel_name_to_creater_.end() == iter) {
     MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!";
     return std::make_pair(false, 0);
   }
@@ -127,16 +140,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
     auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
     // data type matching check of all input parameters of kernel
     for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) {
-      const bool check_sm = mindspore::device::gpu::CudaCommon::GetInstance().check_sm();
-      if (check_sm && marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) {
-        if (marjor_sm < MINIUM_SM) {
-          MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM
-                            << ", but the current device's computing capacity is " << marjor_sm;
-        }
-        MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM
-                        << ", but the current device's computing capacity is " << marjor_sm;
-        mindspore::device::gpu::CudaCommon::GetInstance().set_check_sm(false);
-      }
+      GpuKernelFactory::CheckSM(kernel_info, input_index);
       if (kernel_info->GetInputDeviceType(input_index) !=
           (iter->second)[attr_index].first.GetInputAttr(input_index % attr_size).first) {
         flag = false;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
index 6bb64d5ba5..711817b52c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
@@ -57,6 +57,7 @@ class GpuKernelFactory {
   GpuKernelFactory &operator=(const GpuKernelFactory &);
   std::pair<bool, size_t> GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info);
+  void CheckSM(const KernelBuildInfo *kernel_info, const size_t &input_index);
   bool CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
                     std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, size_t attr_index);
   // map to maintain kernel and creater, KernelAttr object and creater must be registered as a pair.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h index 4f100ead4b..8a6883d0de 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h @@ -44,61 +44,9 @@ class CholeskyTrsmGpuKernel : public GpuKernel { return true; } if (!use_split_matrix) { - auto input1_addr = GetDeviceAddress(inputs, 0); - auto output_addr = GetDeviceAddress(outputs, 0); - auto d_array_addr = GetDeviceAddress(workspace, 0); - auto d_identity_addr = GetDeviceAddress(workspace, 1); - auto d_info_array_addr = GetDeviceAddress(workspace, 2); - for (size_t i = 0; i < batch_; i++) { - h_array[i] = input1_addr + i * lda_ * m_; - h_identity[i] = output_addr + i * ldb_ * m_; - CHECK_CUDA_RET_WITH_ERROR( - cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - } - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - CHECK_CUSOLVER_RET_WITH_EXCEPT( - cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), - "cusolver cholesky batched Fail"); - float alpha = 1; - CHECK_CUBLAS_RET_WITH_EXCEPT( - cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, - d_array_addr, lda_, d_identity_addr, ldb_, batch_), - "cublas trsm batched Fail"); + LaunchNonSplitMatrix(inputs, workspace, outputs, stream_ptr); } else { - auto input1_addr = GetDeviceAddress(inputs, 0); - auto output_addr = GetDeviceAddress(outputs, 0); - auto d_array_addr = GetDeviceAddress(workspace, 0); - auto d_identity_addr = GetDeviceAddress(workspace, 1); - auto d_info_array_addr = GetDeviceAddress(workspace, 2); - auto d_batch_input_addr = GetDeviceAddress(workspace, 3); - for (size_t i = 0; i < batch_; i++) { - h_array[i] = d_batch_input_addr + i * lda_ * m_; - h_identity[i] = output_addr + i * ldb_ * m_; - } - Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast(stream_ptr)); - MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - CHECK_CUSOLVER_RET_WITH_EXCEPT( - cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), - "cusolver cholesky batched Fail"); - float alpha = 1; - CHECK_CUBLAS_RET_WITH_EXCEPT( - cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, - d_array_addr, lda_, d_identity_addr, ldb_, batch_), - "cublas trsm batched Fail"); + LaunchSplitMatrix(inputs, workspace, outputs, stream_ptr); } return true; } @@ -108,92 +56,17 @@ class CholeskyTrsmGpuKernel : 
public GpuKernel { auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); split_dim = static_cast(GetAttr(kernel_node, "split_dim")); if (split_dim == 0) { - use_split_matrix = false; - if (in_shape.size() == 2) { - batch_ = 1; - if (in_shape[0] != in_shape[1]) { - MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input."; - } - } else if (in_shape.size() == 3) { - batch_ = SizeToInt(in_shape[0]); - if (in_shape[1] != in_shape[2]) { - MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input."; - } - } else { - MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; - } - - m_ = SizeToInt(in_shape[1]); - lda_ = m_; - ldb_ = m_; - h_array.resize(batch_); - h_identity.resize(batch_); - h_identity_data.resize(m_ * m_); - for (size_t i = 0; i < m_; i++) { - for (size_t j = 0; j < m_; j++) { - if (i == j) { - h_identity_data[i * m_ + j] = 1; - } else { - h_identity_data[i * m_ + j] = 0; - } - } - } - InitSizeLists(); + InitDim0(kernel_node, in_shape); } else { if (in_shape.size() != 2) { MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Input Rank as 2."; } - height = in_shape[0]; - width = in_shape[1]; - if (height != width) { + if (in_shape[0] != in_shape[1]) { MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Square Matrix as Input."; } - if (SizeToInt(height) <= split_dim) { - use_split_matrix = false; - batch_ = 1; - m_ = SizeToInt(in_shape[1]); - lda_ = m_; - ldb_ = m_; - h_array.resize(batch_); - h_identity.resize(batch_); - h_identity_data.resize(m_ * m_); - for (size_t i = 0; i < m_; i++) { - for (size_t j = 0; j < m_; j++) { - if (i == j) { - h_identity_data[i * m_ + j] = 1; - } else { - h_identity_data[i * m_ + j] = 0; - } - } - } - InitSizeLists(); - } else { - use_split_matrix = true; - int batch = SizeToInt(in_shape[1]) / split_dim; - res_dim = in_shape[1] - batch * split_dim; - if (res_dim == 0) { - batch_ = batch; - } else { - batch_ = batch + 1; - } - m_ = split_dim; - lda_ = m_; - ldb_ = m_; - h_array.resize(batch_); - h_identity.resize(batch_); - h_identity_data.resize(m_ * m_); - for (size_t i = 0; i < m_; i++) { - for (size_t j = 0; j < m_; j++) { - if (i == j) { - h_identity_data[i * m_ + j] = 1; - } else { - h_identity_data[i * m_ + j] = 0; - } - } - } - InitSizeLists(); - } + InitDimOthers(kernel_node, in_shape); } + InitSizeLists(); return true; } @@ -229,6 +102,145 @@ class CholeskyTrsmGpuKernel : public GpuKernel { } private: + void LaunchNonSplitMatrix(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = input1_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + } + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + 
cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + float alpha = 1; + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, + d_array_addr, lda_, d_identity_addr, ldb_, batch_), + "cublas trsm batched Fail"); + } + void LaunchSplitMatrix(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + auto d_batch_input_addr = GetDeviceAddress(workspace, 3); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = d_batch_input_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + } + Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast(stream_ptr)); + MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + float alpha = 1; + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, + d_array_addr, lda_, d_identity_addr, ldb_, batch_), + "cublas trsm batched Fail"); + } + void InitDim0(const CNodePtr &kernel_node, const std::vector &in_shape) { + use_split_matrix = false; + if (in_shape.size() == 2) { + batch_ = 1; + if (in_shape[0] != in_shape[1]) { + MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input."; + } + } else if (in_shape.size() == 3) { + batch_ = SizeToInt(in_shape[0]); + if (in_shape[1] != in_shape[2]) { + MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input."; + } + } else { + MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; + } + + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + } + void InitDimOthers(const CNodePtr &kernel_node, const std::vector &in_shape) { + height = in_shape[0]; + width = in_shape[1]; + if (SizeToInt(height) <= split_dim) { + use_split_matrix = false; + batch_ = 1; + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + } else { + use_split_matrix = true; + int batch = SizeToInt(in_shape[1]) / split_dim; + res_dim = in_shape[1] - batch * split_dim; + if (res_dim == 0) { + batch_ = batch; + } else { + batch_ = batch 
+ 1; + } + m_ = split_dim; + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + } + } size_t batch_; size_t m_; size_t lda_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h index ff686d3666..30ce884e10 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h @@ -159,8 +159,8 @@ class UpdateThorGradientGpuKernel : public GpuKernel { size_t output_size = gradient_size.ori_h * gradient_size.ori_w * unit_size; output_size_list_.push_back(output_size); - size_t workspace_size_ = 0; - workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size; + size_t workspace_size_ = + gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size; workspace_size_list_.push_back(workspace_size_); if (gradient_size.need_convert) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h index d04753e888..559fe59ea2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h @@ -56,51 +56,24 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs, void *stream_ptr) override { - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - - cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); switch (nccl_kernel_type_) { case NCCL_ALL_REDUCE: { - auto all_reduce_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); - MS_EXCEPTION_IF_NULL(all_reduce_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream, group_name_), - "ncclAllReduce failed"); + LaunchAllReduce(inputs, outputs, stream_ptr); break; } case NCCL_ALL_GATHER: { - auto all_gather_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); - MS_EXCEPTION_IF_NULL(all_gather_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT( - (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), - "ncclAllGather failed"); + LaunchAllGather(inputs, outputs, stream_ptr); break; } case NCCL_REDUCE_SCATTER: { - auto reduce_scatter_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); - MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream, group_name_), - "ncclReduceScatter failed"); + LaunchReduceScatter(inputs, outputs, stream_ptr); break; } case NCCL_BROADCAST: { - auto broadcast_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "Broadcast")); - MS_EXCEPTION_IF_NULL(broadcast_funcptr); - for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { - input_addr = GetDeviceAddress(inputs, i); - output_addr = GetDeviceAddress(outputs, i); - CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_list_[i] / sizeof(T), - nccl_data_type_, root_, stream, group_name_), - "ncclBroadcast failed"); - } + LaunchBroadcast(inputs, outputs, stream_ptr); break; } default: { @@ -153,6 +126,59 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { void InitSizeLists() override { return; } private: + void LaunchAllReduce(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); + auto all_reduce_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); + MS_EXCEPTION_IF_NULL(all_reduce_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, + nccl_reduce_type_, stream, group_name_), + "ncclAllReduce failed"); + } + + void LaunchAllGather(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); + auto all_gather_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); + MS_EXCEPTION_IF_NULL(all_gather_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), + "ncclAllGather failed"); + } + + void LaunchReduceScatter(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); + auto reduce_scatter_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); + MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), + nccl_data_type_, nccl_reduce_type_, stream, group_name_), + "ncclReduceScatter failed"); + } + + void LaunchBroadcast(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); + auto broadcast_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "Broadcast")); + MS_EXCEPTION_IF_NULL(broadcast_funcptr); + for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { + input_addr = GetDeviceAddress(inputs, i); + output_addr = GetDeviceAddress(outputs, i); + CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_list_[i] / sizeof(T), + nccl_data_type_, root_, stream, group_name_), + "ncclBroadcast failed"); + } + } + void InferCommType(const CNodePtr &kernel_node) { std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); auto iter = kNcclTypeMap.find(kernel_name); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h index 13506cb2bd..4b6f4bf60b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h @@ -44,25 +44,76 @@ class CtcLossGpuKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { + LaunchInit(inputs, workspace, outputs); + LaunchFirstHalf(inputs, workspace, outputs, stream_ptr); + LaunchSecondHalf(inputs, workspace, outputs, stream_ptr); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (probs_shape.size() != 3) { + MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support."; + } + probs_dims_[0] = probs_shape[0]; + probs_dims_[1] = probs_shape[1]; + probs_dims_[2] = probs_shape[2]; + auto indice_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + if (labels_dims.size() != 1) { + MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support."; + } + if (indice_dims.size() != 2) { + MS_LOG(EXCEPTION) << "labels indice dims: " << indice_dims.size() << " not support."; + } + label_size_ = sizeof(int); + for (auto i : labels_dims) { + label_size_ *= i; + } + label_indice_size_ = sizeof(int64_t); + for (auto i : indice_dims) { + label_indice_size_ *= i; + } + auto squence_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + squence_lengths_size_ = squence_length_dims[0] * sizeof(int); + preprocess_collapse_repeated_ = GetAttr(kernel_node, "preprocess_collapse_repeated"); + ctc_merge_repeated_ = GetAttr(kernel_node, "ctc_merge_repeated"); + ignore_longer_outputs_than_inputs_ = GetAttr(kernel_node, "ignore_longer_outputs_than_inputs"); + InitSizeLists(); + return true; + } + + protected: + void LaunchInit(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + probs = 
GetDeviceAddress<T>(inputs, 0);
+    label_indices = GetDeviceAddress<int64_t>(inputs, 1);
+    label_values = GetDeviceAddress<int>(inputs, 2);
+    sequence_length = GetDeviceAddress<int>(inputs, 3);
+    costs = GetDeviceAddress<T>(outputs, 0);
+    grads = GetDeviceAddress<T>(outputs, 1);
+    softmax_probs = GetDeviceAddress<T>(workspace, 0);
+    cum_labels_length = GetDeviceAddress<int>(workspace, 1);
+    label_squence_length = GetDeviceAddress<int>(workspace, 2);
+    label_value_sp = GetDeviceAddress<int>(workspace, 3);
+    label_value_pcr = GetDeviceAddress<int>(workspace, 4);
+    prob_num = GetDeviceAddress<T>(workspace, 5);
+    precum_labels_length = GetDeviceAddress<int>(workspace, 6);
+    max_labels_length = GetDeviceAddress<int>(workspace, 7);
+    numclass = SizeToInt(probs_dims_[2]);
+    batch = SizeToInt(probs_dims_[1]);
+    max_time = SizeToInt(probs_dims_[0]);
+    max_sequence = 0;
+    max_labels_length_host = 0;
+    batch_label = 0;
+    label_value_with_blank = nullptr;
+    log_alpha_b = nullptr;
+    log_beta_b = nullptr;
+  }
+
+  void LaunchFirstHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                       const std::vector<AddressPtr> &outputs, void *stream_ptr) {
     cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
-    const T *probs = GetDeviceAddress<T>(inputs, 0);
-    const int64_t *label_indices = GetDeviceAddress<int64_t>(inputs, 1);
-    const int *label_values = GetDeviceAddress<int>(inputs, 2);
-    const int *sequence_length = GetDeviceAddress<int>(inputs, 3);
-    T *costs = GetDeviceAddress<T>(outputs, 0);
-    T *grads = GetDeviceAddress<T>(outputs, 1);
-    T *softmax_probs = GetDeviceAddress<T>(workspace, 0);
-    int *cum_labels_length = GetDeviceAddress<int>(workspace, 1);
-    int *label_squence_length = GetDeviceAddress<int>(workspace, 2);
-    int *label_value_sp = GetDeviceAddress<int>(workspace, 3);
-    int *label_value_pcr = GetDeviceAddress<int>(workspace, 4);
-    T *prob_num = GetDeviceAddress<T>(workspace, 5);
-    int *precum_labels_length = GetDeviceAddress<int>(workspace, 6);
-    int *max_labels_length = GetDeviceAddress<int>(workspace, 7);
-    int numclass = SizeToInt(probs_dims_[2]);
-    int batch = SizeToInt(probs_dims_[1]);
-    int max_time = SizeToInt(probs_dims_[0]);
-    int max_sequence = 0;
     CalculateMaxSequence(sequence_length, max_labels_length, batch, stream);
     CHECK_CUDA_RET_WITH_EXCEPT(
       cudaMemcpyAsync(&max_sequence, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
       "cudaMemcpyAsync failed.");
@@ -73,11 +124,7 @@ class CtcLossGpuKernel : public GpuKernel {
     }
     InnerSoftMax(probs, softmax_probs, sequence_length, max_time, batch, numclass, stream);
     MemsetForWS(label_value_pcr, cum_labels_length, label_squence_length, costs, grads, stream);
-    int max_labels_length_host = 0;
-    int batch_label = 0;
-    int *label_value_with_blank = nullptr;
-    T *log_alpha_b = nullptr;
-    T *log_beta_b = nullptr;
+
     CalculatePreLength(label_squence_length, precum_labels_length, cum_labels_length, max_labels_length, label_indices,
                        batch, label_size_ / sizeof(int), stream);
     CHECK_CUDA_RET_WITH_EXCEPT(
@@ -97,8 +144,14 @@ class CtcLossGpuKernel : public GpuKernel {
       cudaMemcpyAsync(&max_labels_length_host, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
       "cudaMemcpyAsync failed.");
     CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
+  }
+
+  void LaunchSecondHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                        const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
     int SOffSet = 2 * max_labels_length_host + 1;
     int log_prob_size = batch * SOffSet * max_time;
+
     if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) {
       MS_LOG(EXCEPTION) << "output size is greater than input 
size."; } @@ -124,43 +177,8 @@ class CtcLossGpuKernel : public GpuKernel { ignore_longer_outputs_than_inputs_, stream); CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); FreeMem(label_value_with_blank, log_alpha_b, log_beta_b); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (probs_shape.size() != 3) { - MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support."; - } - probs_dims_[0] = probs_shape[0]; - probs_dims_[1] = probs_shape[1]; - probs_dims_[2] = probs_shape[2]; - auto indice_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - if (labels_dims.size() != 1) { - MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support."; - } - if (indice_dims.size() != 2) { - MS_LOG(EXCEPTION) << "labels indice dims: " << indice_dims.size() << " not support."; - } - label_size_ = sizeof(int); - for (auto i : labels_dims) { - label_size_ *= i; - } - label_indice_size_ = sizeof(int64_t); - for (auto i : indice_dims) { - label_indice_size_ *= i; - } - auto squence_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - squence_lengths_size_ = squence_length_dims[0] * sizeof(int); - preprocess_collapse_repeated_ = GetAttr(kernel_node, "preprocess_collapse_repeated"); - ctc_merge_repeated_ = GetAttr(kernel_node, "ctc_merge_repeated"); - ignore_longer_outputs_than_inputs_ = GetAttr(kernel_node, "ignore_longer_outputs_than_inputs"); - InitSizeLists(); - return true; } - protected: void InitSizeLists() override { input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T)); input_size_list_.push_back(label_indice_size_); @@ -226,6 +244,31 @@ class CtcLossGpuKernel : public GpuKernel { bool ctc_merge_repeated_; bool ignore_longer_outputs_than_inputs_; T kLogZero_ = -std::numeric_limits::infinity(); + + // Heap parameter + T *probs; + int64_t *label_indices; + int *label_values; + int *sequence_length; + T *costs; + T *grads; + T *softmax_probs; + int *cum_labels_length; + int *label_squence_length; + int *label_value_sp; + int *label_value_pcr; + T *prob_num; + int *precum_labels_length; + int *max_labels_length; + int numclass; + int batch; + int max_time; + int max_sequence; + int max_labels_length_host; + int batch_label; + int *label_value_with_blank; + T *log_alpha_b; + T *log_beta_b; }; // namespace kernel } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h index e1c3a5d325..65875f9f04 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h @@ -82,17 +82,11 @@ class L2NormalizeGpuKernel : public GpuKernel { return true; } + bool Init(const CNodePtr &kernel_node) override { InitResource(); data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but l2normalize op needs 1 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but 
l2normalize op needs 1 output."; + if (!CheckIONumber(kernel_node)) { return false; } int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size()); @@ -176,6 +170,19 @@ class L2NormalizeGpuKernel : public GpuKernel { } private: + bool CheckIONumber(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but l2normalize op needs 1 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but l2normalize op needs 1 output."; + return false; + } + return true; + } void DestroyResource() noexcept { CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_), "cudnnDestroyReduceTensorDescriptor failed."); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h index 76253dc341..39e334f3b3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h @@ -104,20 +104,31 @@ class L2NormalizeGradGpuKernel : public GpuKernel { return true; } + + bool CheckInputShape(const std::vector &output_shape) { + for (auto &shape : input_shape_list_) { + if (output_shape != shape) { + MS_LOG(EXCEPTION) << "Input shape and output shape should be same!"; + } + } + is_null_input_ = CHECK_NULL_INPUT(input_shape_list_[0]); + if (is_null_input_) { + MS_LOG(WARNING) << "L2NormalizeGPUKernel input is null"; + InitSizeLists(); + return false; + } + if (input_shape_list_[0].size() > MAX_DIMS) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7"; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { InitResource(); data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != INPUT_SIZE) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but l2normalize op needs 3 inputs."; + if (!CheckIONumber(kernel_node)) { return false; } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but l2normalize op needs 1 output."; - return false; - } - int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size()); int axis = static_cast(GetAttr(kernel_node, "axis")); axis_ = axis < 0 ? 
(axis + input_dim_length) : axis; @@ -128,10 +139,8 @@ class L2NormalizeGradGpuKernel : public GpuKernel { input_shape_list_.emplace_back(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i)); } auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (auto &shape : input_shape_list_) { - if (output_shape != shape) { - MS_LOG(EXCEPTION) << "Input shape and output shape should be same!"; - } + if (!CheckInputShape(output_shape)) { + return true; } output_size_ = sizeof(T); @@ -139,16 +148,6 @@ class L2NormalizeGradGpuKernel : public GpuKernel { output_size_ *= dim; } - is_null_input_ = CHECK_NULL_INPUT(input_shape_list_[0]); - if (is_null_input_) { - MS_LOG(WARNING) << "L2NormalizeGPUKernel input is null"; - InitSizeLists(); - return true; - } - if (input_shape_list_[0].size() > MAX_DIMS) { - MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7"; - } - std::vector output_reduce_shape = output_shape; output_reduce_shape[axis_] = 1; @@ -173,6 +172,19 @@ class L2NormalizeGradGpuKernel : public GpuKernel { } protected: + bool CheckIONumber(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != INPUT_SIZE) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but l2normalize op needs 3 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but l2normalize op needs 1 output."; + return false; + } + return true; + } void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h index 169bbd7e3f..84bd46a4ff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h @@ -53,14 +53,7 @@ class PadGpuFwdKernel : public GpuKernel { } bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but Pad needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; + if (!CheckIONumber(kernel_node)) { return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -114,6 +107,20 @@ class PadGpuFwdKernel : public GpuKernel { } private: + bool CheckIONumber(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but Pad needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; + return false; + } + return true; + } + size_t shape_size_; size_t temp; std::vector> paddings; // list of paddings (tuple of tuple in python) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h index a3a8456800..49a4142943 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
@@ -76,11 +76,9 @@ class PoolingGradGpuKernel : public GpuKernel {
                                 "cudnnPoolingBackward failed");
     return true;
   }
-  bool Init(const CNodePtr &kernel_node) override {
-    InitResource();
-    if (!CheckParam(kernel_node)) {
-      return false;
-    }
+
+  bool InitShape(const CNodePtr &kernel_node, int *dimA, int *strideAin, int *dimAy, int *strideAiny, int *dimAdy,
+                 int *strideAdy, int *dimAout, int *strideAout) {
     auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
     auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
     auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
@@ -95,9 +93,25 @@ class PoolingGradGpuKernel : public GpuKernel {
     if (is_null_input_) {
       MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
       InitSizeLists();
-      return true;
+      return false;
     }
     SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
+    SetDimA(input_shape, dimA, 4, data_format);
+    SetStrideA(input_shape, strideAin, 4, data_format);
+    SetDimA(input_mask, dimAy, 4, data_format);
+    SetStrideA(input_mask, strideAiny, 4, data_format);
+    SetDimA(dout_shape, dimAdy, 4, data_format);
+    SetStrideA(dout_shape, strideAdy, 4, data_format);
+    SetDimA(output_shape, dimAout, 4, data_format);
+    SetStrideA(output_shape, strideAout, 4, data_format);
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
     const int nbDims = 4;
     int dimA[4];
     int strideAin[4];
@@ -107,14 +121,9 @@ class PoolingGradGpuKernel : public GpuKernel {
     int strideAdy[4];
     int dimAout[4];
     int strideAout[4];
-    SetDimA(input_shape, dimA, 4, data_format);
-    SetStrideA(input_shape, strideAin, 4, data_format);
-    SetDimA(input_mask, dimAy, 4, data_format);
-    SetStrideA(input_mask, strideAiny, 4, data_format);
-    SetDimA(dout_shape, dimAdy, 4, data_format);
-    SetStrideA(dout_shape, strideAdy, 4, data_format);
-    SetDimA(output_shape, dimAout, 4, data_format);
-    SetStrideA(output_shape, strideAout, 4, data_format);
+    if (!InitShape(kernel_node, dimA, strideAin, dimAy, strideAiny, dimAdy, strideAdy, dimAout, strideAout)) {
+      return true;
+    }
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
                                 "cudnnSetTensor4dDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
@@ -129,6 +138,7 @@ class PoolingGradGpuKernel : public GpuKernel {
     InitSizeLists();
     return true;
   }
+
   void DestroyResource() noexcept override {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_),
                                "cudnnDestroyPoolingDescriptor failed");