From efdfbee9df61bbd501f81a4bbb429cfbd03edc6d Mon Sep 17 00:00:00 2001 From: VectorSL Date: Mon, 1 Jun 2020 21:34:06 +0800 Subject: [PATCH] gpu slice support null output --- .../ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h index 81910b5091..7f71e548ad 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h @@ -27,7 +27,8 @@ namespace kernel { template class SliceGpuFwdKernel : public GpuKernel { public: - SliceGpuFwdKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {} + SliceGpuFwdKernel() + : is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {} ~SliceGpuFwdKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } @@ -35,6 +36,9 @@ class SliceGpuFwdKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); if (is_strided_slice_) { @@ -79,7 +83,11 @@ class SliceGpuFwdKernel : public GpuKernel { if (size_[i] < 0) { size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; } - if (size_[i] == 0) { + if (begin_[i] == size_[i] && is_strided_slice_) { + MS_LOG(WARNING) << "Output is null."; + is_null_input_ = true; + } + if (size_[i] == 0 && strides_[i] > 0) { size_[i] = begin_[i] + 1; } } @@ -143,6 +151,7 @@ class SliceGpuFwdKernel : public GpuKernel { std::vector workspace_size_list_; bool is_strided_slice_; + bool is_null_input_; size_t input_size_; size_t output_size_; size_t workspace_size_;