diff --git a/mindspore/ccsrc/kernel/CMakeLists.txt b/mindspore/ccsrc/kernel/CMakeLists.txt index 993768b7f7..88960ad355 100644 --- a/mindspore/ccsrc/kernel/CMakeLists.txt +++ b/mindspore/ccsrc/kernel/CMakeLists.txt @@ -27,7 +27,7 @@ if (ENABLE_CPU) list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_cpu_kernel.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/subscalar_cpu_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/sub_cpu_kernel.cc") endif () endif () diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc index e56afaf9b5..b4eaa82bc6 100644 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc @@ -48,19 +48,14 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0); MS_LOG(DEBUG) << "output type: " << output_type; - int axis = AnfAlgo::GetNodeAttr(kernel_node, "axis"); - MS_LOG(DEBUG) << "axis: " << axis; - if (axis_ < 0) { - axis = axis + SizeToInt(input_shape_.size()); - } - axis_ = 4 - input_shape_.size() + axis; + axis_ = 4 - input_shape_.size(); MS_LOG(DEBUG) << "axis_: " << axis_; reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, "reduce_scatter_flag"); MS_LOG(DEBUG) << "reduce_scatter_flag: " << reduce_scatter_flag_; if (reduce_scatter_flag_) { size_t gatherv2_out_lens = 1; for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { - if (i == axis) { + if (i == 0) { for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) { MS_LOG(DEBUG) << "gatherv2 out shape: " << indices_shape_[j]; gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j]; @@ -76,7 +71,10 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { if (gather_v2_out_ == nullptr) { MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; } - memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_); + auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_); + if (ret != 0) { + MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed"; + } split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); MS_LOG(DEBUG) << "split_num: " << split_num_; @@ -99,6 +97,12 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inp auto output_addr = reinterpret_cast(outputs[0]->addr); MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size; float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast(gather_v2_out_) : output_addr; + if (!reduce_scatter_flag_) { + auto ret = memset_s(gather_out_addr, outputs[0]->size, 0, outputs[0]->size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset out buff failed"; + } + } MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr; size_t dim0 = input_shape_[0]; size_t dim1 = input_shape_[1]; @@ -149,10 +153,10 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inp return true; } -void memcpy_task(std::vector mem_dest_addr_list, std::vector mem_src_addr_list, size_t start, +void memcpy_task(std::vector *mem_dest_addr_list, std::vector *mem_src_addr_list, size_t start, size_t end, size_t lens) { for (size_t i = start; i < end; i++) { - auto ret = memcpy_s(mem_dest_addr_list[i], lens, mem_src_addr_list[i], lens); + auto ret = memcpy_s((*mem_dest_addr_list)[i], lens, (*mem_src_addr_list)[i], lens); if (ret != EOK) { MS_LOG(EXCEPTION) << "memery copy failed."; } @@ -204,7 +208,7 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector break; } auto end = (start + ones_copy_lens) > memcpy_lens ? memcpy_lens : start + ones_copy_lens; - threads[i] = std::thread(memcpy_task, mem_dest_addr_list, mem_src_addr_list, start, end, lens); + threads[i] = std::thread(memcpy_task, &mem_dest_addr_list, &mem_src_addr_list, start, end, lens); start = start + ones_copy_lens; } for (size_t j = 0; j < i; j++) { diff --git a/mindspore/ccsrc/kernel/cpu/subscalar_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc similarity index 76% rename from mindspore/ccsrc/kernel/cpu/subscalar_cpu_kernel.cc rename to mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc index 435154561b..9fe36fba81 100644 --- a/mindspore/ccsrc/kernel/cpu/subscalar_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc @@ -14,14 +14,20 @@ * limitations under the License. */ #include -#include "kernel/cpu/subscalar_cpu_kernel.h" +#include "kernel/cpu/sub_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" namespace mindspore { namespace kernel { -void SubscalarCPUKernel::InitKernel(const CNodePtr &kernel_node) { - offset_ = AnfAlgo::GetNodeAttr(kernel_node, "input_y"); - MS_LOG(DEBUG) << "offset: " << offset_; +void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + if (shape.size() == 1) { + if (shape[0] != 1) { + MS_LOG(EXCEPTION) << "input 1 only support scalar"; + } + } else { + MS_LOG(EXCEPTION) << "input 1 only support scalar"; + } } void sub_task(int *in_addr, int *out_addr, size_t lens, int offset) { @@ -30,9 +36,9 @@ void sub_task(int *in_addr, int *out_addr, size_t lens, int offset) { } } -bool SubscalarCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { +bool SubCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { #if defined(_WIN32) || defined(_WIN64) auto start_time = std::chrono::steady_clock::now(); #else @@ -41,6 +47,8 @@ bool SubscalarCPUKernel::Launch(const std::vector &inputs, #endif auto input_addr = reinterpret_cast(inputs[0]->addr); auto output_addr = reinterpret_cast(outputs[0]->addr); + offset_ = *reinterpret_cast(inputs[1]->addr); + MS_LOG(INFO) << "offset: " << offset_; auto lens = inputs[0]->size / sizeof(int); if (lens < 10000) { for (size_t i = 0; i < lens; i++) { @@ -73,7 +81,7 @@ bool SubscalarCPUKernel::Launch(const std::vector &inputs, (void)gettimeofday(&end_time, nullptr); uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "SubscalarCPUKernel, used time: " << time << " us"; + MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us"; #endif return true; } diff --git a/mindspore/ccsrc/kernel/cpu/subscalar_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h similarity index 70% rename from mindspore/ccsrc/kernel/cpu/subscalar_cpu_kernel.h rename to mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h index bd70b075ee..5530962d7f 100644 --- a/mindspore/ccsrc/kernel/cpu/subscalar_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUBSCALAR_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SUBSCALAR_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ #include #include #include "kernel/cpu/cpu_kernel.h" @@ -22,10 +22,10 @@ namespace mindspore { namespace kernel { -class SubscalarCPUKernel : public CPUKernel { +class SubCPUKernel : public CPUKernel { public: - SubscalarCPUKernel() : offset_(0) {} - ~SubscalarCPUKernel() override = default; + SubCPUKernel() : offset_(0) {} + ~SubCPUKernel() override = default; void InitKernel(const CNodePtr &kernel_node) override; @@ -36,9 +36,8 @@ class SubscalarCPUKernel : public CPUKernel { int offset_; }; -MS_REG_CPU_KERNEL(Subscalar, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SubscalarCPUKernel); +MS_REG_CPU_KERNEL(Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SubCPUKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUBSCALAR_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_