From 447a45dbe702fa498d4508e4959a034eabd7d2cc Mon Sep 17 00:00:00 2001 From: VectorSL Date: Fri, 9 Oct 2020 09:34:53 +0800 Subject: [PATCH] change gpu kernel shape to size_t --- .../gpu/arrays/argmaxwithvalue_gpu_kernel.h | 6 +- .../gpu/arrays/array_reduce_gpu_kernel.h | 18 +- .../gpu/arrays/broadcast_to_gpu_kernel.h | 4 +- .../gpu/arrays/slice_gpu_kernel.h | 4 +- .../gpu/arrays/slice_grad_gpu_kernel.h | 13 +- .../gpu/arrays/strided_slice_gpu_kernel.h | 4 +- .../arrays/strided_slice_grad_gpu_kernel.h | 13 +- .../gpu/arrays/transpose_gpu_kernel.h | 15 +- .../gpu/cuda_impl/argmaxwithvalue_impl.cu | 32 +-- .../gpu/cuda_impl/argmaxwithvalue_impl.cuh | 2 +- .../gpu/cuda_impl/broadcast_impl.cu | 135 ++++++------ .../gpu/cuda_impl/broadcast_impl.cuh | 15 +- .../gpu/cuda_impl/slice_impl.cu | 193 +++++++++--------- .../gpu/cuda_impl/slice_impl.cuh | 15 +- .../gpu/cuda_impl/transpose_impl.cu | 35 ++-- .../gpu/cuda_impl/transpose_impl.cuh | 4 +- .../backend/kernel_compiler/gpu/gpu_kernel.h | 10 +- .../gpu/math/broadcast_gpu_kernel.h | 6 +- .../gpu/math/broadcast_grad_gpu_kernel.h | 12 +- .../gpu/nn/activation_gpu_kernel.h | 16 +- .../gpu/nn/activation_grad_kernel.h | 16 +- .../gpu/nn/softmax_gpu_kernel.h | 22 +- .../gpu/nn/softmax_grad_gpu_kernel.h | 22 +- 23 files changed, 325 insertions(+), 287 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h index a6cb342268..6ba4cef6f2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -86,9 +86,9 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; - int bound_; - int outerSize_; - int innerSize_; + size_t bound_; + size_t outerSize_; + size_t innerSize_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h index 8a273965ea..16307cfa41 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h @@ -180,15 +180,16 @@ class ArrayReduceGpuKernel : public GpuKernel { return; } void InferInAndOutDesc(const std::vector &input_shape, const std::vector &output_shape) { - std::vector inputA; + std::vector inputA; std::vector outputC_shape = output_shape; const int split_dim = 4; if (input_shape.size() <= split_dim) { ShapeNdTo4d(input_shape, &inputA); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, - inputA[0], inputA[1], inputA[2], inputA[3]), - "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, SizeToInt(inputA[0]), + SizeToInt(inputA[1]), SizeToInt(inputA[2]), SizeToInt(inputA[3])), + "cudnnSetTensor4dDescriptor failed"); } else { CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_); for (auto dim : input_shape) { @@ -216,7 +217,7 @@ class ArrayReduceGpuKernel : public GpuKernel { return; } - std::vector outputC; + std::vector outputC; if (!keep_dims_) { for (auto i : axis_) { (void)(outputC_shape.insert(outputC_shape.begin() + i, 1)); @@ -225,9 +226,10 @@ 
class ArrayReduceGpuKernel : public GpuKernel { if (outputC_shape.size() <= split_dim) { ShapeNdTo4d(outputC_shape, &outputC); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, - outputC[0], outputC[1], outputC[2], outputC[3]), - "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, SizeToInt(outputC[0]), + SizeToInt(outputC[1]), SizeToInt(outputC[2]), SizeToInt(outputC[3])), + "cudnnSetTensor4dDescriptor failed"); } else { CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_); for (auto dim : outputC_shape) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h index 280879b81c..db78bd75e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h @@ -71,8 +71,8 @@ class BroadcastToGpuKernel : public GpuKernel { } private: - int input_shape_[4] = {1, 1, 1, 1}; - int output_shape_[4] = {1, 1, 1, 1}; + size_t input_shape_[4] = {1, 1, 1, 1}; + size_t output_shape_[4] = {1, 1, 1, 1}; std::vector input_size_list_; std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h index 8bb1448bd2..ea385ab1eb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h @@ -61,7 +61,7 @@ class SliceGpuFwdKernel : public GpuKernel { (void)size_.insert(size_.begin(), 1); } - input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); + input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T); auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); output_size_ = sizeof(T); @@ -118,7 +118,7 @@ class SliceGpuFwdKernel : public GpuKernel { } std::vector begin_; std::vector size_; - std::vector input_shape_; + std::vector input_shape_; std::vector input_size_list_; std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h index 45566c9d69..c1827b44b7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h @@ -50,7 +50,10 @@ class SliceGradGpuKernel : public GpuKernel { auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); if (kernel_name == "StridedSliceGrad") { is_strided_slice_ = true; - input_shape_ = GetAttr>(kernel_node, "shapex"); + auto shapex = GetAttr>(kernel_node, "shapex"); + for (auto x : shapex) { + input_shape_.push_back(IntToSize(x)); + } for (auto i = input_shape_.size(); i < 4; i++) { (void)input_shape_.insert(input_shape_.begin(), 1); } @@ -69,11 +72,11 @@ class SliceGradGpuKernel : public GpuKernel { ShapeNdTo4d(dy_shape, &dy_shape_); begin_ = GetAttr>(kernel_node, "begin"); DealParam(); - input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); + input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T); output_size_ = sizeof(T); for (auto x : 
dy_shape_) { - output_size_ = output_size_ * IntToSize(x); + output_size_ = output_size_ * x; } InitSizeLists(); return true; } @@ -125,8 +128,8 @@ class SliceGradGpuKernel : public GpuKernel { std::vector begin_; std::vector size_; std::vector strides_; - std::vector input_shape_; - std::vector dy_shape_; + std::vector input_shape_; + std::vector dy_shape_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h index 6d5e506782..3587f7e4e8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h @@ -72,7 +72,7 @@ class StridedSliceGpuKernel : public GpuKernel { } input_size_list_.push_back(size); - int size1 = sizeof(T); + size_t size1 = sizeof(T); for (size_t i = 0; i < MAX_DIMS; i++) { size1 *= output_shape_[i]; } @@ -188,7 +188,7 @@ class StridedSliceGpuKernel : public GpuKernel { std::vector begin_; std::vector end_; std::vector strides_; std::vector input_shape_; - std::vector output_shape_; + std::vector output_shape_; int null_output_; std::vector input_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h index 737dcdb3e3..6d092d18f5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h @@ -50,7 +50,10 @@ class StridedSliceGradGpuKernel : public GpuKernel { return true; } bool Init(const CNodePtr &kernel_node) override { - input_shape_ = GetAttr>(kernel_node, "shapex"); + auto shapex = GetAttr>(kernel_node, "shapex"); + for (auto x : shapex) { + input_shape_.push_back(IntToSize(x)); + } if (input_shape_.size() > MAX_DIMS) { MS_LOG(ERROR) << "StridedSliceGrad supports at most " << MAX_DIMS << " dims, but got " << input_shape_.size(); return false; } @@ -66,13 +69,13 @@ class StridedSliceGradGpuKernel : public GpuKernel { protected: void InitSizeLists() override { - int size = sizeof(T); + size_t size = sizeof(T); for (size_t i = 0; i < MAX_DIMS; i++) { size *= output_shape_[i]; } input_size_list_.push_back(size); - int size1 = sizeof(T); + size_t size1 = sizeof(T); for (size_t i = 0; i < MAX_DIMS; i++) { size1 *= input_shape_[i]; } @@ -187,8 +190,8 @@ class StridedSliceGradGpuKernel : public GpuKernel { std::vector begin_; std::vector end_; std::vector strides_; - std::vector input_shape_; - std::vector output_shape_; + std::vector input_shape_; + std::vector output_shape_; int null_output_; std::vector input_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h index f2ca5c3296..bf40bb6475 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h @@ -37,17 +37,16 @@ class TransposeGpuFwdKernel : public GpuKernel { const std::vector &outputs, void *stream_ptr) override { T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); - int *input_shape = GetDeviceAddress(workspace, 0); - int *input_axis = GetDeviceAddress(workspace, 1); + size_t *input_shape = GetDeviceAddress(workspace, 0); + 
size_t *input_axis = GetDeviceAddress(workspace, 1); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), "cudaMemcpyAsync input_shape failed"); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, - reinterpret_cast(stream_ptr)); + size_t size = input_size_ / sizeof(T); + CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast(stream_ptr)); return true; } @@ -88,15 +87,15 @@ class TransposeGpuFwdKernel : public GpuKernel { void InitSizeLists() override { input_size_list_.push_back(input_size_); output_size_list_.push_back(output_size_); - workspace_size_ = shape_size_ * sizeof(int); + workspace_size_ = shape_size_ * sizeof(size_t); workspace_size_list_.push_back(workspace_size_); workspace_size_list_.push_back(workspace_size_); return; } private: - std::vector input_shape_; - std::vector input_axis_; + std::vector input_shape_; + std::vector input_axis_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu index 66a73aca50..0c4f0198f1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu @@ -18,14 +18,16 @@ #include "runtime/device/gpu/cuda_common.h" #include "include/cuda_fp16.h" template -__global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize, int innerSize, S *index, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize; pos += gridDim.x * blockDim.x) { - int x = pos / innerSize % outerSize; - int y = pos % innerSize; +__global__ void ArgmaxWithValue(const T *input, const size_t bound, size_t outerSize, + size_t innerSize, S *index, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize; + pos += gridDim.x * blockDim.x) { + size_t x = pos / innerSize % outerSize; + size_t y = pos % innerSize; S idx = 0; - int InputOffset = x * bound * innerSize + 0 * innerSize + y; + size_t InputOffset = x * bound * innerSize + 0 * innerSize + y; T maxData = input[InputOffset]; - for (int i = 0; i < bound; i++) { + for (size_t i = 0; i < bound; i++) { InputOffset = x * bound * innerSize + i * innerSize + y; auto inputData = input[InputOffset]; idx = inputData > maxData ? 
i : idx; @@ -38,14 +40,16 @@ __global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize, } template -void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index, - T *output, cudaStream_t cuda_stream) { - ArgmaxWithValue<<>>(input, bound_, outerSize_, innerSize_, index, - output); +void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, + S *index, T *output, cudaStream_t cuda_stream) { + ArgmaxWithValue<<>>(input, bound_, outerSize_, innerSize_, + index, output); return; } -template void CalArgmaxWithValue(const float *input, const int bound_, const int outerSize_, - const int innerSize_, int *index, float *output, cudaStream_t cuda_stream); -template void CalArgmaxWithValue(const half *input, const int bound_, const int outerSize_, - const int innerSize_, int *index, half *output, cudaStream_t cuda_stream); +template void CalArgmaxWithValue(const float *input, const size_t bound_, const size_t outerSize_, + const size_t innerSize_, int *index, float *output, + cudaStream_t cuda_stream); +template void CalArgmaxWithValue(const half *input, const size_t bound_, const size_t outerSize_, + const size_t innerSize_, int *index, half *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh index 67c061a966..9bdcab3eec 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh @@ -17,6 +17,6 @@ #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ template -void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index, +void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, S *index, T *output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index 798794025b..ff109e886a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -219,31 +219,33 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int * cudaStream_t stream); // Broadcast comparison -__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 
0 : index; } template -__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5, - const int l6, const int r0, const int r1, const int r2, const int r3, const int r4, - const int r5, const int r6, const int d0, const int d1, const int d2, const int d3, - const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) { +__global__ void BroadcastCmpKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, + const size_t l4, const size_t l5, const size_t l6, const size_t r0, + const size_t r1, const size_t r2, const size_t r3, const size_t r4, + const size_t r5, const size_t r6, const size_t d0, const size_t d1, + const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const size_t d6, const T *x0, const T *x1, bool *y) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; pos += blockDim.x * gridDim.x) { - int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0; - int j = pos / (d2 * d3 * d4 * d5 * d6) % d1; - int k = pos / (d3 * d4 * d5 * d6) % d2; - int l = pos / (d4 * d5 * d6) % d3; - int m = pos / (d5 * d6) % d4; - int n = pos / d6 % d5; - int o = pos % d6; - - int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6; + size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0; + size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1; + size_t k = pos / (d3 * d4 * d5 * d6) % d2; + size_t l = pos / (d4 * d5 * d6) % d3; + size_t m = pos / (d5 * d6) % d4; + size_t n = pos / d6 % d5; + size_t o = pos % d6; + + size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6; l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6; l_index += Index(k, l2) * l3 * l4 * l5 * l6; l_index += Index(l, l3) * l4 * l5 * l6; l_index += Index(m, l4) * l5 * l6; l_index += Index(n, l5) * l6; l_index += Index(o, l6); - int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6; + size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6; r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6; r_index += Index(k, r2) * r3 * r4 * r5 * r6; r_index += Index(l, r3) * r4 * r5 * r6; @@ -255,9 +257,10 @@ __global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, con } template -void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, const std::vector &y_dims, - enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) { - int size = 1; +void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const T *x0, + const T *x1, bool *y, cudaStream_t stream) { + size_t size = 1; for (auto d : y_dims) { size *= d; } @@ -278,40 +281,42 @@ void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_di } } -template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, - const std::vector &y_dims, enum BroadcastOpType op, const float *x0, const float *x1, +template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const float *x0, const float *x1, bool *y, cudaStream_t stream); -template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, - const std::vector &y_dims, enum BroadcastOpType op, const half *x0, const half *x1, +template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const half *x0, const half *x1, bool *y, cudaStream_t stream); -template void BroadcastCmp(const 
std::vector &x0_dims, const std::vector &x1_dims, - const std::vector &y_dims, enum BroadcastOpType op, const int *x0, const int *x1, +template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const int *x0, const int *x1, bool *y, cudaStream_t stream); // Broadcast Arithmetic template -__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5, - const int l6, const int r0, const int r1, const int r2, const int r3, const int r4, - const int r5, const int r6, const int d0, const int d1, const int d2, const int d3, - const int d4, const int d5, const int d6, const T *x0, const T *x1, T *y) { +__global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, + const size_t l4, const size_t l5, const size_t l6, const size_t r0, + const size_t r1, const size_t r2, const size_t r3, const size_t r4, + const size_t r5, const size_t r6, const size_t d0, const size_t d1, + const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const size_t d6, const T *x0, const T *x1, T *y) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; pos += blockDim.x * gridDim.x) { - int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0; - int j = pos / (d2 * d3 * d4 * d5 * d6) % d1; - int k = pos / (d3 * d4 * d5 * d6) % d2; - int l = pos / (d4 * d5 * d6) % d3; - int m = pos / (d5 * d6) % d4; - int n = pos / d6 % d5; - int o = pos % d6; - - int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6; + size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0; + size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1; + size_t k = pos / (d3 * d4 * d5 * d6) % d2; + size_t l = pos / (d4 * d5 * d6) % d3; + size_t m = pos / (d5 * d6) % d4; + size_t n = pos / d6 % d5; + size_t o = pos % d6; + + size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6; l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6; l_index += Index(k, l2) * l3 * l4 * l5 * l6; l_index += Index(l, l3) * l4 * l5 * l6; l_index += Index(m, l4) * l5 * l6; l_index += Index(n, l5) * l6; l_index += Index(o, l6); - int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6; + size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6; r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6; r_index += Index(k, r2) * r3 * r4 * r5 * r6; r_index += Index(l, r3) * r4 * r5 * r6; @@ -323,9 +328,10 @@ __global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, c } template -void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, const std::vector &y_dims, - enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) { - int size = 1; +void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const T *x0, + const T *x1, T *y, cudaStream_t stream) { + size_t size = 1; for (auto d : y_dims) { size *= d; } @@ -385,41 +391,44 @@ void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_ } } -template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, - const std::vector &y_dims, enum BroadcastOpType op, const float *x0, const float *x1, - float *y, cudaStream_t stream); -template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, - const std::vector &y_dims, enum BroadcastOpType op, const half *x0, const half *x1, +template void BroadcastArith(const std::vector &x0_dims, const 
std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const float *x0, + const float *x1, float *y, cudaStream_t stream); +template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const half *x0, const half *x1, half *y, cudaStream_t stream); -template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, - const std::vector &y_dims, enum BroadcastOpType op, const int *x0, const int *x1, +template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const int *x0, const int *x1, int *y, cudaStream_t stream); // BroadcastTo template -__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1, - const int o2, const int o3, const T *input_addr, T *output_addr) { +__global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0, + const size_t o1, const size_t o2, const size_t o3, const T *input_addr, + T *output_addr) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) { - int i = pos / (o1 * o2 * o3) % o0; - int j = pos / (o2 * o3) % o1; - int k = pos / o3 % o2; - int l = pos % o3; + size_t i = pos / (o1 * o2 * o3) % o0; + size_t j = pos / (o2 * o3) % o1; + size_t k = pos / o3 % o2; + size_t l = pos % o3; - int input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3); + size_t input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3); output_addr[pos] = input_addr[input_idx]; } } template -void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, - const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) { - int nums = o0 * o1 * o2 * o3; +void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0, + const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, + T *output_addr, cudaStream_t stream) { + size_t nums = o0 * o1 * o2 * o3; BroadcastToKernel<<>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr, output_addr); } -template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, - const int &o2, const int &o3, const float *input_addr, float *output_addr, - cudaStream_t stream); -template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, - const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream); +template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0, + const size_t &o1, const size_t &o2, const size_t &o3, const float *input_addr, + float *output_addr, cudaStream_t stream); +template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0, + const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr, + half *output_addr, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index 9f0a5ba984..09f0992e50 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -43,14 +43,17 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream); template -void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, const std::vector &y_dims, - enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream); +void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, + cudaStream_t stream); template -void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, const std::vector &y_dims, - enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream); +void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y, + cudaStream_t stream); template -void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, - const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream); +void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0, + const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, T *output_addr, + cudaStream_t stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu index 3b68941080..62ca154c18 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu @@ -21,16 +21,17 @@ #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" template -__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, +__global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const size_t s4, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const T *input, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { - int i = pos / (l2 * l3 * l4) % l1; - int j = pos / (l3 * l4) % l2; - int k = pos / l4 % l3; - int o = pos % l4; + size_t i = pos / (l2 * l3 * l4) % l1; + size_t j = pos / (l3 * l4) % l2; + size_t k = pos / l4 % l3; + size_t o = pos % l4; - int offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4); + size_t offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4); output[pos] = input[offset]; } } @@ -56,18 +57,19 @@ void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaSt return; } template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, - const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output, - cudaStream_t stream) { +void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const T *input, T *output, 
cudaStream_t stream) { Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4, input, output); } template -void CalSliceGrad(const size_t input_size, const T *dy, const std::vector in_shape, const std::vector begin, - const std::vector size, T *output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; +void CalSliceGrad(const size_t input_size, const T *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, T *output, + cudaStream_t cuda_stream) { + size_t block = in_shape[1] * in_shape[2] * in_shape[3]; + size_t map = in_shape[2] * in_shape[3]; + size_t w = in_shape[3]; int length = size[3]; int p = 0; for (int i = begin[0]; i < size[0] + begin[0]; i++) { @@ -82,23 +84,24 @@ } template -__global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int b4, - const int b5, const int b6, const int s0, const int s1, const int s2, - const int s3, const int s4, const int s5, const int s6, const int i0, - const int i1, const int i2, const int i3, const int i4, const int i5, - const int i6, const int o0, const int o1, const int o2, const int o3, - const int o4, const int o5, const int o6, const T *input_addr, T *output_addr) { +__global__ void StridedSliceKernel(const size_t b0, const size_t b1, const size_t b2, const size_t b3, const size_t b4, + const size_t b5, const size_t b6, const size_t s0, const size_t s1, const size_t s2, + const size_t s3, const size_t s4, const size_t s5, const size_t s6, const size_t i0, + const size_t i1, const size_t i2, const size_t i3, const size_t i4, const size_t i5, + const size_t i6, const size_t o0, const size_t o1, const size_t o2, const size_t o3, + const size_t o4, const size_t o5, const size_t o6, + const T *input_addr, T *output_addr) { size_t output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { - int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0; - int j = pos / (o2 * o3 * o4 * o5 * o6) % o1; - int k = pos / (o3 * o4 * o5 * o6) % o2; - int l = pos / (o4 * o5 * o6) % o3; - int m = pos / (o5 * o6) % o4; - int n = pos / (o6) % o5; - int o = pos % o6; - - int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \ + size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0; + size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1; + size_t k = pos / (o3 * o4 * o5 * o6) % o2; + size_t l = pos / (o4 * o5 * o6) % o3; + size_t m = pos / (o5 * o6) % o4; + size_t n = pos / (o6) % o5; + size_t o = pos % o6; + + size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \ + (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \ + (n * s5 + b5) * i6 + (o * s6 + b6); output_addr[pos] = input_addr[input_idx]; @@ -107,10 +110,10 @@ __global__ void StridedSliceKernel(const int b0, const int b1, const int b2, con template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, const T *input, T *output, + const std::vector &strides, const std::vector &output_shape, const T *input, T *output, cudaStream_t cuda_stream) { - int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \ - * output_shape[4] * output_shape[5] * 
output_shape[6]; + size_t size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \ + * output_shape[4] * output_shape[5] * output_shape[6]; StridedSliceKernel<<>>( begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6], @@ -120,23 +123,25 @@ void StridedSlice(const std::vector &input_shape, const std::vector } template -__global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int b4, - const int b5, const int b6, const int s0, const int s1, const int s2, - const int s3, const int s4, const int s5, const int s6, const int i0, - const int i1, const int i2, const int i3, const int i4, const int i5, - const int i6, const int o0, const int o1, const int o2, const int o3, - const int o4, const int o5, const int o6, const T *dy, T *dx) { +__global__ void StridedSliceGradKernel(const size_t b0, const size_t b1, const size_t b2, const size_t b3, + const size_t b4, const size_t b5, const size_t b6, const size_t s0, + const size_t s1, const size_t s2, const size_t s3, const size_t s4, + const size_t s5, const size_t s6, const size_t i0, const size_t i1, + const size_t i2, const size_t i3, const size_t i4, const size_t i5, + const size_t i6, const size_t o0, const size_t o1, const size_t o2, + const size_t o3, const size_t o4, const size_t o5, const size_t o6, + const T *dy, T *dx) { size_t output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { - int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0; - int j = pos / (o2 * o3 * o4 * o5 * o6) % o1; - int k = pos / (o3 * o4 * o5 * o6) % o2; - int l = pos / (o4 * o5 * o6) % o3; - int m = pos / (o5 * o6) % o4; - int n = pos / (o6) % o5; - int o = pos % o6; - - int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \ + size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0; + size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1; + size_t k = pos / (o3 * o4 * o5 * o6) % o2; + size_t l = pos / (o4 * o5 * o6) % o3; + size_t m = pos / (o5 * o6) % o4; + size_t n = pos / (o6) % o5; + size_t o = pos % o6; + + size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \ + (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \ + (n * s5 + b5) * i6 + (o * s6 + b6); dx[input_idx] = dy[pos]; @@ -145,9 +150,10 @@ __global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, } template -void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, - const std::vector &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream) { - int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6]; +void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, + const T *dy, T *dx, cudaStream_t cuda_stream) { + size_t size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6]; StridedSliceGradKernel<<>>( begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6], @@ -157,88 +163,89 @@ void StridedSliceGrad(const std::vector &dy_shape, const std::vector & } template void FillDeviceArray(const size_t 
input_size, float *addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, - const float *input, float *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const float *dy, const std::vector in_shape, +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const float *dy, const std::vector in_shape, const std::vector begin, const std::vector size, float *output, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, - const half *input, half *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const half *dy, const std::vector in_shape, +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const half *dy, const std::vector in_shape, const std::vector begin, const std::vector size, half *output, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, - const int *input, int *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const int *dy, const std::vector in_shape, +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const int *dy, const std::vector in_shape, const std::vector begin, const std::vector size, int *output, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, short *addr, const float value, cudaStream_t cuda_stream); // NOLINT -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, - const short *input, short *output, cudaStream_t stream); // NOLINT -template void CalSliceGrad(const size_t input_size, const short *dy, const std::vector in_shape, // NOLINT +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const short *input, short 
*output, cudaStream_t stream); // NOLINT +template void CalSliceGrad(const size_t input_size, const short *dy, const std::vector in_shape, // NOLINT const std::vector begin, const std::vector size, short *output, // NOLINT cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, unsigned char *addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, - const unsigned char *input, unsigned char *output, cudaStream_t stream); +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output, + cudaStream_t stream); template void CalSliceGrad(const size_t input_size, const unsigned char *dy, - const std::vector in_shape, const std::vector begin, + const std::vector in_shape, const std::vector begin, const std::vector size, unsigned char *output, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, - const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, - const bool *input, bool *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const bool *dy, const std::vector in_shape, +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const bool *dy, const std::vector in_shape, const std::vector begin, const std::vector size, bool *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, const float *input, + const std::vector &strides, const std::vector &output_shape, const float *input, float *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, const half *input, + const std::vector &strides, const std::vector &output_shape, const half *input, half *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, const int *input, + const std::vector &strides, const std::vector &output_shape, const int *input, int *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, + const std::vector &strides, const std::vector &output_shape, const short *input, short *output, cudaStream_t cuda_stream); // NOLINT template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, + const std::vector &strides, const std::vector &output_shape, const unsigned char 
*input, unsigned char *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, const bool *input, + const std::vector &strides, const std::vector &output_shape, const bool *input, bool *output, cudaStream_t cuda_stream); -template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, - const std::vector &strides, const std::vector &dx_shape, const float *dy, +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const float *dy, float *dx, cudaStream_t cuda_stream); -template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, - const std::vector &strides, const std::vector &dx_shape, const half *dy, +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const half *dy, half *dx, cudaStream_t cuda_stream); -template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, - const std::vector &strides, const std::vector &dx_shape, const int *dy, +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const int *dy, int *dx, cudaStream_t cuda_stream); -template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, - const std::vector &strides, const std::vector &dx_shape, const short *dy, // NOLINT +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const short *dy, // NOLINT short *dx, cudaStream_t cuda_stream); // NOLINT -template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, - const std::vector &strides, const std::vector &dx_shape, +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream); -template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, - const std::vector &strides, const std::vector &dx_shape, const bool *dy, +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const bool *dy, bool *dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh index 70b013174e..7aa45ec9b2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh @@ -22,19 +22,20 @@ #include "runtime/device/gpu/cuda_common.h" template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, - const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output, - cudaStream_t stream); +void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, + const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, + const T *input, T *output, cudaStream_t stream); template -void CalSliceGrad(const size_t input_size, const T *input, const 
std::vector in_shape, +void CalSliceGrad(const size_t input_size, const T *input, const std::vector in_shape, const std::vector begin, const std::vector size, T *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, - const std::vector &strides, const std::vector &output_shape, const T *input, T *output, + const std::vector &strides, const std::vector &output_shape, const T *input, T *output, cudaStream_t cuda_stream); template -void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, - const std::vector &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream); +void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const T *dy, T *dx, + cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu index fe38188930..3a0bea4d51 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu @@ -20,19 +20,19 @@ #include "runtime/device/gpu/cuda_common.h" template -__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis, - const int shape_size, T* output) { - int pos_size; - int temp_pos; - int newpos; - int newpos_size; - int pos_array[TRANSPOSE_MAX_DIMENSION]; +__global__ void Transpose(const size_t size, const T* input, const size_t* input_shape, + const size_t* input_axis, const size_t shape_size, T* output) { + size_t pos_size; + size_t temp_pos; + size_t newpos; + size_t newpos_size; + size_t pos_array[TRANSPOSE_MAX_DIMENSION]; // for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] + // posArray[1] * input_shape[2] * input_shape[3] + // posArray[2] * input_shape[3] + // posArray[3] - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { temp_pos = pos; pos_size = size / input_shape[0]; pos_array[0] = temp_pos / pos_size; @@ -54,16 +54,19 @@ __global__ void Transpose(const int size, const T* input, const int* input_shape return; } template -void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size, - T* output, cudaStream_t cuda_stream) { +void CalTranspose(const size_t size, const T* input, const size_t* input_shape, const size_t* input_axis, + const size_t shape_size, T* output, cudaStream_t cuda_stream) { Transpose<<>>(size, input, input_shape, input_axis, shape_size, output); return; } -template void CalTranspose(const int size, const float* input, const int* input_shape, const int* input_axis, - const int shape_size, float* output, cudaStream_t cuda_stream); -template void CalTranspose(const int size, const half* input, const int* input_shape, const int* input_axis, - const int shape_size, half* output, cudaStream_t cuda_stream); -template void CalTranspose(const int size, const int* input, const int* input_shape, const int* input_axis, - const int shape_size, int* output, cudaStream_t cuda_stream); +template void 
CalTranspose(const size_t size, const float* input, const size_t* input_shape, + const size_t* input_axis, const size_t shape_size, float* output, + cudaStream_t cuda_stream); +template void CalTranspose(const size_t size, const half* input, const size_t* input_shape, + const size_t* input_axis, const size_t shape_size, half* output, + cudaStream_t cuda_stream); +template void CalTranspose(const size_t size, const int* input, const size_t* input_shape, + const size_t* input_axis, const size_t shape_size, int* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh index dbf7d140eb..f1fea0a83c 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh @@ -19,7 +19,7 @@ #define TRANSPOSE_MAX_DIMENSION 100 template -void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size, - T* output, cudaStream_t cuda_stream); +void CalTranspose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis, + const size_t shape_size, T *output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h index fe363145d0..4d74c4d988 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h @@ -78,14 +78,14 @@ class GpuKernel : public KernelMod { return GetValue(attr); } // expand Nd Shape to 4d (N in [0,4]) - void ShapeNdTo4d(const std::vector &src, std::vector *dst) { + void ShapeNdTo4d(const std::vector &src, std::vector *dst) { if (src.size() > 4) { MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!"; } - dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4])); - dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3])); - dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2])); - dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1])); + dst->push_back(src.size() < 4 ? 1 : src[src.size() - 4]); + dst->push_back(src.size() < 3 ? 1 : src[src.size() - 3]); + dst->push_back(src.size() < 2 ? 1 : src[src.size() - 2]); + dst->push_back(src.size() == 0 ? 
1 : src[src.size() - 1]); } int AxisTransform(const std::string &origin_data_format, const std::string &cal_format, int axis) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index b739969a3f..b06c1801c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -162,9 +162,9 @@ class BroadcastOpGpuKernel : public GpuKernel { int input1_num_; int input2_num_; int output_num_; - std::vector lhs_shape_; - std::vector rhs_shape_; - std::vector output_shape_; + std::vector lhs_shape_; + std::vector rhs_shape_; + std::vector output_shape_; std::vector input_size_list_; std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h index 0d7d9bb2cf..200335af88 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h @@ -140,12 +140,12 @@ class BroadcastOpGradGpuKernel : public GpuKernel { BroadcastGradOpType op_type_; bool need_broadcast_; - int input1_num_; - int input2_num_; - int output_num_; - int x1_shape_[4] = {1, 1, 1, 1}; - int x2_shape_[4] = {1, 1, 1, 1}; - int dy_shape_[4] = {1, 1, 1, 1}; + size_t input1_num_; + size_t input2_num_; + size_t output_num_; + size_t x1_shape_[4] = {1, 1, 1, 1}; + size_t x2_shape_[4] = {1, 1, 1, 1}; + size_t dy_shape_[4] = {1, 1, 1, 1}; bool grad_x_; bool grad_y_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h index b49e3b86a1..0853973872 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h @@ -82,7 +82,7 @@ class ActivationGpuFwdKernel : public GpuKernel { InitSizeLists(); return true; } - std::vector shape; + std::vector shape; double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 
6.0 : 0.0; CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef), "cudnnSetActivationDescriptor failed"); @@ -91,13 +91,15 @@ class ActivationGpuFwdKernel : public GpuKernel { if (input_shape.size() <= split_dim) { ShapeNdTo4d(input_shape, &shape); if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, - shape[0], shape[3], shape[1], shape[2]), - "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]), + SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])), + "cudnnSetTensor4dDescriptor failed"); } else { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - shape[0], shape[1], shape[2], shape[3]), - "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]), + SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])), + "cudnnSetTensor4dDescriptor failed"); } } else { CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h index c6fd1a0921..86709bd76a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h @@ -89,7 +89,7 @@ class ActivationGradGpuKernel : public GpuKernel { InitSizeLists(); return true; } - std::vector shape; + std::vector shape; double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
index c6fd1a0921..86709bd76a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -89,7 +89,7 @@ class ActivationGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
-    std::vector<int> shape;
+    std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
                                 "SetActivationDescriptor failed");
@@ -98,13 +98,15 @@ class ActivationGradGpuKernel : public GpuKernel {
     if (input_shape.size() <= split_dim) {
       ShapeNdTo4d(input_shape, &shape);
       if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
-        CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
-                                                               shape[0], shape[3], shape[1], shape[2]),
-                                    "cudnnSetTensor4dDescriptor failed");
+        CHECK_CUDNN_RET_WITH_EXCEPT(
+          cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
+                                     SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
+          "cudnnSetTensor4dDescriptor failed");
       } else {
-        CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                               shape[0], shape[1], shape[2], shape[3]),
-                                    "cudnnSetTensor4dDescriptor failed");
+        CHECK_CUDNN_RET_WITH_EXCEPT(
+          cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
+                                     SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
+          "cudnnSetTensor4dDescriptor failed");
       }
     } else {
       CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
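The gradient kernel repeats the same descriptor branch, including the NHWC case, where the dimensions are handed to cuDNN in (n, c, h, w) argument order, hence the shape[0], shape[3], shape[1], shape[2] reordering. A hypothetical helper condensing that branch (plain static_cast stands in for SizeToInt here):

#include <cudnn.h>
#include <cstddef>
#include <vector>

// Set a 4-D tensor descriptor from a size_t shape. cuDNN always receives the
// dims as (n, c, h, w); only the position of C in the source layout differs.
inline cudnnStatus_t SetTensor4d(cudnnTensorDescriptor_t desc, bool nhwc, cudnnDataType_t type,
                                 const std::vector<size_t> &s) {
  const int n = static_cast<int>(s[0]);
  const int c = static_cast<int>(nhwc ? s[3] : s[1]);
  const int h = static_cast<int>(nhwc ? s[1] : s[2]);
  const int w = static_cast<int>(nhwc ? s[2] : s[3]);
  return cudnnSetTensor4dDescriptor(desc, nhwc ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW, type, n, c, h, w);
}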
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h
index 9369ba0a55..20163bbd86 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h
@@ -68,9 +68,9 @@ class SoftmaxGpuKernel : public GpuKernel {
     } else {
       T *transpose_input_addr = GetDeviceAddress<T>(workspace, 0);
       T *transpose_output_addr = GetDeviceAddress<T>(workspace, 1);
-      int *input_shape = GetDeviceAddress<int>(workspace, 2);
-      int *transpose_shape = GetDeviceAddress<int>(workspace, 3);
-      int *transpose_axis = GetDeviceAddress<int>(workspace, 4);
+      size_t *input_shape = GetDeviceAddress<size_t>(workspace, 2);
+      size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 3);
+      size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 4);
       CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                                                  reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync input_shape failed");
@@ -80,7 +80,7 @@ class SoftmaxGpuKernel : public GpuKernel {
       CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
                                                  cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync input_axis failed");
-      int size = SizeToInt(input_size_ / sizeof(T));
+      size_t size = input_size_ / sizeof(T);
       CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr,
                    reinterpret_cast<cudaStream_t>(stream_ptr));
       CHECK_CUDNN_RET_WITH_EXCEPT(
@@ -113,7 +113,7 @@ class SoftmaxGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
-    shape_size_ = SizeToInt(input_shape.size());
+    shape_size_ = input_shape.size();
     auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
     if (kernel_name == "LogSoftmax") {
       algo_ = CUDNN_SOFTMAX_LOG;
@@ -171,7 +171,7 @@ class SoftmaxGpuKernel : public GpuKernel {
   void InitSizeByAxis2D(const std::vector<size_t> &input_shape, const int &axis) {
     axis_ = axis;
     if (axis_ < 0) {
-      axis_ += shape_size_;
+      axis_ += SizeToInt(shape_size_);
     }
     if (axis_ == 1) {
       batch_size_ = input_shape[0];
@@ -193,7 +193,7 @@ class SoftmaxGpuKernel : public GpuKernel {
     width_ = 1;
     input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
     output_size_ = input_size_;
-    workspace_size_ = IntToSize(shape_size_) * sizeof(int);
+    workspace_size_ = shape_size_ * sizeof(size_t);
   }
 
   void InitSizeByAxisLastDim(const std::vector<size_t> &input_shape, const int &axis) {
@@ -235,11 +235,11 @@ class SoftmaxGpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 
-  std::vector<int> input_shape_;
-  std::vector<int> transpose_shape_;
-  std::vector<int> transpose_axis_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> transpose_shape_;
+  std::vector<size_t> transpose_axis_;
   int axis_;
-  int shape_size_;
+  size_t shape_size_;
 
   size_t batch_size_;
   size_t channel_size_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h
index 7bea9b3569..5fc24b8ed8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h
@@ -62,9 +62,9 @@ class SoftmaxGradGpuKernel : public GpuKernel {
     T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0);
     T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1);
     T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2);
-    int *input_shape = GetDeviceAddress<int>(workspace, 3);
-    int *transpose_shape = GetDeviceAddress<int>(workspace, 4);
-    int *transpose_axis = GetDeviceAddress<int>(workspace, 5);
+    size_t *input_shape = GetDeviceAddress<size_t>(workspace, 3);
+    size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 4);
+    size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 5);
     const float alpha = 1;
     const float beta = 0;
 
@@ -82,7 +82,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
     CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync input_axis failed");
-    int size = SizeToInt(input_size_ / sizeof(T));
+    size_t size = input_size_ / sizeof(T);
     CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
     CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr,
@@ -116,7 +116,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
       InitSizeLists();
      return true;
     }
-    shape_size_ = SizeToInt(input_shape.size());
+    shape_size_ = input_shape.size();
     if (shape_size_ != 2) {
       MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs.";
     }
@@ -164,7 +164,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
   void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) {
     axis_ = axis;
     if (axis_ < 0) {
-      axis_ += shape_size_;
+      axis_ += SizeToInt(shape_size_);
     }
     if (axis_ == 1) {
       batch_size_ = input_shape[0];
@@ -186,7 +186,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
     width_ = 1;
     input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
     output_size_ = input_size_;
-    workspace_size_ = IntToSize(shape_size_) * sizeof(int);
+    workspace_size_ = shape_size_ * sizeof(size_t);
   }
 
   cudnnHandle_t cudnn_handle_;
@@ -202,11 +202,11 @@ class SoftmaxGradGpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 
-  std::vector<int> input_shape_;
-  std::vector<int> transpose_shape_;
-  std::vector<int> transpose_axis_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> transpose_shape_;
+  std::vector<size_t> transpose_axis_;
   int axis_;
-  int shape_size_;
+  size_t shape_size_;
 
   size_t batch_size_;
   size_t channel_size_;
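The softmax kernels stage their shape and axis arrays in device workspace, so the element-type change also changes the sizing: workspace_size_ becomes rank * sizeof(size_t), and that same byte count feeds every cudaMemcpyAsync that fills those buffers. A sketch of the invariant with hypothetical names; a stale sizeof(int) here would copy only half of each 64-bit dimension:

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

// Copy a host-side shape into a device buffer of size_t elements. The byte
// count is derived from the destination element type, matching the
// workspace_size_ computation in the patch above.
cudaError_t CopyShapeToDevice(const std::vector<size_t> &shape, size_t *device_buf, cudaStream_t stream) {
  const size_t bytes = shape.size() * sizeof(size_t);
  return cudaMemcpyAsync(device_buf, shape.data(), bytes, cudaMemcpyHostToDevice, stream);
}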