!7140 UnsortedSegmentSum null check

Merge pull request !7140 from chenweifeng/unsorted-null-checkout
pull/7140/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 2fb6b0cc5d

@ -27,7 +27,8 @@ namespace kernel {
template <typename T, typename S>
class UnsortedSegmentSumGpuKernel : public GpuKernel {
public:
UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
UnsortedSegmentSumGpuKernel()
: input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1), is_null_input_(false) {}
~UnsortedSegmentSumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -36,6 +37,9 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
@ -50,6 +54,12 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
if (is_null_input_) {
MS_LOG(WARNING) << "UnsortedSegmentSum input is null";
InitSizeLists();
return true;
}
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
@ -83,6 +93,7 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
size_t input_dim1_;
size_t output_dim0_;
size_t output_dim1_;
bool is_null_input_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;

Loading…
Cancel
Save