From 3bef4e9b75f2f55f381fbdab6ce27c91a3a7f21c Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Thu, 28 Jan 2021 16:11:16 -0500 Subject: [PATCH] range better error initial commit --- .../gpu/arrays/dynamic_range_gpu_kernel.h | 45 +++++++++- .../gpu/cuda_impl/dynamic_range_impl.cu | 85 +++++++++++++------ .../gpu/cuda_impl/dynamic_range_impl.cuh | 17 +++- tests/st/ops/gpu/test_range_op.py | 32 ++++++- 4 files changed, 143 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h index 3ee237f1c8..1be2cd92b7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,22 +44,58 @@ class DynamicRangeGpuKernel : public GpuKernel { T *range_delta = GetDeviceAddress(inputs, 2); T *output_device_address = GetDeviceAddress(outputs, 0); int64_t *output_shape_device_address = GetDeviceAddress(workspace, 0); + DynamicRangeErrorCode *error_code_device_address = GetDeviceAddress(workspace, 1); stream_ptr_ = stream_ptr; - CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address, - max_output_length_, reinterpret_cast(stream_ptr)); + CudaValidateInputAndInferShape(range_start, range_end, range_delta, output_shape_device_address, + error_code_device_address, max_output_length_, + reinterpret_cast(stream_ptr)); + + DynamicRangeErrorCode error_code = DynamicRangeErrorCode::kOk; + + CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_, + cudaMemcpyAsync(&error_code, error_code_device_address, sizeof(DynamicRangeErrorCode), + cudaMemcpyDeviceToHost, reinterpret_cast(stream_ptr)), + "Failed to copy error code to host."); + CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed"); // use workspace[0] for actual output shape, we know it must be 1d CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_, cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost, reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); + "Failed to copy output_shape to host."); CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed"); + LogExceptionIfNotOk(error_code); + + CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address, + error_code_device_address, max_output_length_, reinterpret_cast(stream_ptr)); + return true; } + void LogExceptionIfNotOk(DynamicRangeErrorCode error_code) { + switch (error_code) { + case DynamicRangeErrorCode::kOk: + return; + case DynamicRangeErrorCode::kDeltaIsZero: + MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be equal to zero"; + break; + case DynamicRangeErrorCode::kInvalidPositiveDelta: + MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be positive when limit < start"; + break; + case DynamicRangeErrorCode::kInvalidNegativeDelta: + MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be negative when limit > start"; + break; + case DynamicRangeErrorCode::kMaxSizeExceeded: + MS_LOG(EXCEPTION) << "gpu RangeOp memory error: the number of elements in the output exceeds maxlen"; + break; + default: + MS_LOG(EXCEPTION) << "gpu RangeOp unknown error"; + } + } + void PostExecute() override { // required synchronize for PostExecute CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), @@ -103,6 +139,7 @@ class DynamicRangeGpuKernel : public GpuKernel { // this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape. workspace_size_list_.push_back(sizeof(int64_t)); + workspace_size_list_.push_back(sizeof(DynamicRangeErrorCode)); return; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu index 51a9051dd2..13e42d4650 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,57 +20,90 @@ #include "runtime/device/gpu/cuda_common.h" template -__device__ void CheckInputs(const T &start, const T &end, const T &delta) { +__global__ void ValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta, + int64_t *output_shape, DynamicRangeErrorCode *error_code, + const int64_t max_output_size) { + T start = range_start[0]; + T end = range_end[0]; + T delta = range_delta[0]; + *error_code = DynamicRangeErrorCode::kOk; + if (delta == 0) { - asm("trap;"); + *error_code = DynamicRangeErrorCode::kDeltaIsZero; + return; } if (start < end && delta < 0) { - asm("trap;"); + *error_code = DynamicRangeErrorCode::kInvalidNegativeDelta; + return; } if (start > end && delta > 0) { - asm("trap;"); + *error_code = DynamicRangeErrorCode::kInvalidPositiveDelta; + return; + } + + if (*error_code == DynamicRangeErrorCode::kOk) { + int64_t real_output_shape = static_cast(ceil(static_cast(end - start) / delta)); + if (real_output_shape > max_output_size) { + *error_code = DynamicRangeErrorCode::kMaxSizeExceeded; + } + *output_shape = real_output_shape; } } template -__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, - int64_t *output_shape, const int64_t max_output_size) { +__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, + const int64_t max_output_size) { T start = range_start[0]; - T end = range_end[0]; T delta = range_delta[0]; - CheckInputs(start, end, delta); - - int64_t real_output_shape = static_cast(ceil(static_cast(end - start) / delta)); - if (real_output_shape > max_output_size) { - asm("trap;"); - } - *output_shape = real_output_shape; - size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; - for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) { + for (; gt_id < *output_shape; gt_id += blockDim.x * gridDim.x) { output[gt_id] = gt_id * delta + start; } } +template +void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta, + int64_t *output_shape, DynamicRangeErrorCode *error_code, + const int64_t max_output_size, cudaStream_t cuda_stream) { + ValidateInputAndInferShape<<<1, 1, 0, cuda_stream>>>(range_start, range_end, range_delta, output_shape, error_code, + max_output_size); +} + template void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, - const int64_t max_output_size, cudaStream_t cuda_stream) { + DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream) { Range<<>>(range_start, range_end, range_delta, output, output_shape, max_output_size); } -template void CalRange(const int *range_start, const int *range_end, const int *range_delta, int *output, - int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream); +template void CudaValidateInputAndInferShape(const int *range_start, const int *range_end, const int *range_delta, + int64_t *output_shape, DynamicRangeErrorCode *error_code, + const int64_t max_output_size, cudaStream_t cuda_stream); +template void CudaValidateInputAndInferShape(const int64_t *range_start, const int64_t *range_end, + const int64_t *range_delta, int64_t *output_shape, + DynamicRangeErrorCode *error_code, const int64_t max_output_size, + cudaStream_t cuda_stream); +template void CudaValidateInputAndInferShape(const float *range_start, const float *range_end, + const float *range_delta, int64_t *output_shape, + DynamicRangeErrorCode *error_code, const int64_t max_output_size, + cudaStream_t cuda_stream); +template void CudaValidateInputAndInferShape(const double *range_start, const double *range_end, + const double *range_delta, int64_t *output_shape, + DynamicRangeErrorCode *error_code, const int64_t max_output_size, + cudaStream_t cuda_stream); +template void CalRange(const int *range_start, const int *range_end, const int *range_delta, int *output, + int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size, + cudaStream_t cuda_stream); template void CalRange(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta, - int64_t *output, int64_t *output_shape, const int64_t max_output_size, - cudaStream_t cuda_stream); - + int64_t *output, int64_t *output_shape, DynamicRangeErrorCode *error_code, + const int64_t max_output_size, cudaStream_t cuda_stream); template void CalRange(const float *range_start, const float *range_end, const float *range_delta, float *output, - int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream); + int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size, + cudaStream_t cuda_stream); template void CalRange(const double *range_start, const double *range_end, const double *range_delta, - double *output, int64_t *output_shape, const int64_t max_output_size, - cudaStream_t cuda_stream); + double *output, int64_t *output_shape, DynamicRangeErrorCode *error_code, + const int64_t max_output_size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh index 17b1fd8c0a..535e344303 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,21 @@ #include +enum class DynamicRangeErrorCode { + kOk = 0, + kDeltaIsZero, + kInvalidPositiveDelta, + kInvalidNegativeDelta, + kMaxSizeExceeded +}; + +template +void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta, + int64_t *output_shape, DynamicRangeErrorCode *error_code, + const int64_t max_output_size, cudaStream_t cuda_stream); + template void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, - const int64_t max_output_size, cudaStream_t cuda_stream); + DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_ diff --git a/tests/st/ops/gpu/test_range_op.py b/tests/st/ops/gpu/test_range_op.py index de76b01525..102f45bfd9 100644 --- a/tests/st/ops/gpu/test_range_op.py +++ b/tests/st/ops/gpu/test_range_op.py @@ -22,12 +22,12 @@ from mindspore import Tensor from mindspore.ops import operations as P class RangeNet(nn.Cell): - def __init__(self): + def __init__(self, maxlen=10000): super(RangeNet, self).__init__() - self.range = P.Range() + self.range = P.Range(maxlen) - def construct(self, s, e, d): - return self.range(s, e, d) + def construct(self, start, limit, delta): + return self.range(start, limit, delta) @pytest.mark.level0 @@ -91,3 +91,27 @@ def test_range_invalid_max_output_length(): _ = P.Range(-1) _ = P.Range(None) _ = P.Range('5') + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_range_invalid_input(): + with pytest.raises(RuntimeError) as info: + range_net = RangeNet(3500) + _ = range_net(Tensor(0, mstype.int32), Tensor(5, mstype.int32), Tensor(0, mstype.int32)).asnumpy() + assert "delta cannot be equal to zero" in str(info.value) + + with pytest.raises(RuntimeError) as info: + range_net = RangeNet(2) + _ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy() + assert "number of elements in the output exceeds maxlen" in str(info.value) + + with pytest.raises(RuntimeError) as info: + range_net = RangeNet(3500) + _ = range_net(Tensor(20, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy() + assert "delta cannot be positive when limit < start" in str(info.value) + + with pytest.raises(RuntimeError) as info: + range_net = RangeNet(3500) + _ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(-4, mstype.int32)).asnumpy() + assert "delta cannot be negative when limit > start" in str(info.value)