From 3fdc3629af882520b948ba94e46e1a12df2c96b8 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Thu, 9 Jul 2020 11:16:36 +0800 Subject: [PATCH] Check attr exists before getting it in embeddinglookup cpu kernel --- .../ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc | 8 ++++++-- mindspore/ccsrc/utils/utils.h | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc index c8c2c667ad..f2fd7fc650 100644 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc @@ -36,7 +36,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { } output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); axis_ = 4 - input_shape_.size(); - reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, "reduce_scatter_flag"); + if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) { + reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrReduceScatterFlag); + } #ifdef ENABLE_MPI if (reduce_scatter_flag_) { size_t gatherv2_out_lens = 1; @@ -65,7 +67,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; } #endif - offset_ = AnfAlgo::GetNodeAttr(kernel_node, "offset"); + if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) { + offset_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrOffset); + } CPUKernelUtils::ExpandDimsTo4(&input_shape_); CPUKernelUtils::ExpandDimsTo4(&output_shape_); } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index d10d5830fa..a5ec56cb2f 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -223,6 +223,8 @@ constexpr auto kAttrNumSplit = "num_split"; constexpr auto kAttrOutputNum = "output_num"; constexpr auto kAttrSizeSplits = "size_splits"; constexpr auto kAttrOutputDefault = "output_default"; +constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag"; +constexpr auto kAttrOffset = "offset"; // attr value constexpr auto kValueTargetSwitch = "target_switch";