diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc index 1160ba57b7..0d49846bf7 100644 --- a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc +++ b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc @@ -179,8 +179,8 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec return result; } -bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t data_num, - const std::string &op_type, float *output) { +bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t input_data_num, + size_t output_size, const std::string &op_type, float *output) { int scatter_index = GetScatterIndex(rank_id_, ranks_group); auto group = AddGroup(ranks_group); if (group == MPI_GROUP_NULL) { @@ -193,7 +193,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t data_num, const std::string &op_type = kOpTypeSum); - bool ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t data_num, - const std::string &op_type = kOpTypeSum, float *output = nullptr); - bool AllGather(float *input, float *output, const std::vector &ranks_group, size_t data_num); + bool ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t input_data_num, + size_t output_size, const std::string &op_type = kOpTypeSum, + float *output = nullptr); + bool AllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num); private: MPIAdapter(); diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc index 043fb7a0df..abb0c65d27 100644 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc @@ -26,21 +26,11 @@ constexpr auto kRanksGroup = "group"; constexpr auto kAllGatherInputNum = 1; } // namespace -AllGatherCPUKernel::AllGatherCPUKernel() : input_data_number_(0) {} - void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != kAllGatherInputNum) { MS_LOG(EXCEPTION) << "allgather input num:" << input_num; } - for (size_t i = 0; i < input_num; ++i) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - size_t count = 1; - for (size_t j = 0; j < shape.size(); j++) { - count *= IntToSize(shape[j]); - } - input_data_number_ += count; - } auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); if (ranks_group != nullptr) { @@ -55,8 +45,9 @@ bool AllGatherCPUKernel::Launch(const std::vector &inputs, const std::vector &outputs) { auto input_addr = reinterpret_cast(inputs[0]->addr); auto output_addr = reinterpret_cast(outputs[0]->addr); + auto input_data_num = inputs[0]->size / sizeof(float); - return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_number_); + return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h index 6c06d3e851..94180fa89b 100644 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h @@ -24,7 +24,7 @@ namespace mindspore { namespace kernel { class AllGatherCPUKernel : public CPUKernel { public: - AllGatherCPUKernel(); + AllGatherCPUKernel() = default; ~AllGatherCPUKernel() override = default; void InitKernel(const CNodePtr &kernel_node) override; @@ -33,7 +33,6 @@ class AllGatherCPUKernel : public CPUKernel { const std::vector &outputs) override; private: - size_t input_data_number_; std::vector ranks_group_; }; diff --git a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc index 69cea98a47..fd8a74eb6b 100644 --- a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc @@ -24,18 +24,9 @@ namespace { constexpr auto kRanksGroup = "group"; } // namespace -ReduceScatterCPUKernel::ReduceScatterCPUKernel() : output_data_number_(0), op_type_(device::cpu::kOpTypeSum) {} +ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t i = 0; i < output_num; ++i) { - auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); - size_t size = 1; - for (size_t j = 0; j < shape.size(); j++) { - size *= IntToSize(shape[j]); - } - output_data_number_ += size; - } auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); if (op != nullptr) { op_type_ = GetValue(op); @@ -54,8 +45,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector &input const std::vector &outputs) { auto input_addr = reinterpret_cast(inputs[0]->addr); auto output_addr = reinterpret_cast(outputs[0]->addr); + auto output_data_num = outputs[0]->size / sizeof(float); - return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_number_, + return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_); } } // namespace kernel diff --git a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h index cebf649954..c3bfe571a4 100644 --- a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h @@ -33,7 +33,6 @@ class ReduceScatterCPUKernel : public CPUKernel { const std::vector &outputs) override; private: - size_t output_data_number_; std::string op_type_; std::vector ranks_group_; };