From 7265a49400c183c46927fad81f0aa3a29738a40c Mon Sep 17 00:00:00 2001 From: alouhahaha Date: Tue, 1 Dec 2020 16:02:24 +0800 Subject: [PATCH] add allgather fusion --- .../kernel_compiler/hccl/hccl_kernel.cc | 18 ++++- .../backend/kernel_compiler/hccl/hcom_util.cc | 7 +- .../enhancer/concat_outputs_for_all_gather.cc | 66 +++++++------------ .../optimizer/pass/communication_op_fusion.cc | 43 ++++++++---- 4 files changed, 74 insertions(+), 60 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 9ad10704f0..28a19fa0d7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -133,8 +133,22 @@ const std::vector &HcclKernel::GetOutputSizeList() const { if (!output_size_list_.empty()) { return output_size_list_; } - for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { - if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) { + auto cnode = anf_node_->cast(); + auto op_name = AnfAlgo::GetCNodeName(cnode); + int64_t rank_size = 1; + if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { + rank_size = AnfAlgo::GetNodeAttr(cnode, kAttrRankSize); + } + int64_t fusion = 0; + if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode)) { + fusion = AnfAlgo::GetNodeAttr(cnode, kAttrFusion); + } + ulong loop_size = hccl_data_type_list_.size(); + if (op_name == kAllGatherOpName && fusion >= 1) { + loop_size *= rank_size; + } + for (ulong i = 0; i < loop_size; ++i) { + if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) { MS_LOG(ERROR) << "GetHcclOpOutputSize failed"; } output_size_list_.push_back(size); diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc index 8d1836f3b8..7e883a8b88 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc @@ -127,7 +127,12 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vectorcast(); + if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr(anf_node, kAttrFusion)) { + block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + } else { + block_size = input_size; + } } else { block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc index 2c21292e79..7589ed1d12 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc @@ -20,58 +20,33 @@ namespace mindspore { namespace opt { -namespace { -void AddOutputs(const AnfNodePtr &node, int64_t rank_size) { - MS_EXCEPTION_IF_NULL(node); - auto origin_abstract = node->abstract(); - MS_EXCEPTION_IF_NULL(origin_abstract); - auto tuple_abstract = origin_abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - auto &origin_abstracts = tuple_abstract->elements(); - AbstractBasePtrList abstract_list; - std::vector outputs_device_type; - std::vector outputs_device_format; - for (int64_t i = 0; i < rank_size; ++i) { - for (size_t j = 0; j < origin_abstracts.size(); ++j) { - abstract_list.push_back(origin_abstracts[j]); - outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j)); - outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j)); - } - } - // Update abstract - auto new_abstracts = std::make_shared(abstract_list); - node->set_abstract(new_abstracts); - // Update kernel build info - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); - builder->SetOutputsDeviceType(outputs_device_type); - builder->SetOutputsFormat(outputs_device_format); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); -} -} // namespace - AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::vector &new_tuple_getitems, int64_t rank_size) const { MS_EXCEPTION_IF_NULL(func_graph); - std::vector make_tuple_inputs; + std::vector make_tuple_inputs{NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; size_t inputs_size = AnfAlgo::GetInputTensorNum(node); for (size_t i = 0; i < inputs_size; ++i) { - for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) { - std::vector concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + std::vector concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) { concat_inputs.push_back(new_tuple_getitems[idx]); - auto concat = func_graph->NewCNode(concat_inputs); - MS_EXCEPTION_IF_NULL(concat); - MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]); - concat->set_abstract(new_tuple_getitems[idx]->abstract()); - AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast(0)), concat); - AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); - std::vector dyn_input_size{rank_size}; - AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); - kernel_select_->SelectKernel(concat); - make_tuple_inputs.push_back(concat); } + auto concat = func_graph->NewCNode(concat_inputs); + MS_EXCEPTION_IF_NULL(concat); + MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]); + auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)}; + std::vector shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0); + shape[0] *= rank_size; + std::vector> shapes = {shape}; + AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get()); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast(0)), concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); + std::vector dyn_input_size{rank_size}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); + kernel_select_->SelectKernel(concat); + make_tuple_inputs.push_back(concat); } + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); return make_tuple; } @@ -94,8 +69,11 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra if (fusion <= 0) { return nullptr; } + if (AnfAlgo::HasNodeAttr("fused", cnode)) { + return nullptr; + } + AnfAlgo::SetNodeAttr("fused", MakeValue(true), node); auto rank_size = AnfAlgo::GetNodeAttr(node, kAttrRankSize); - AddOutputs(node, rank_size); std::vector new_outputs; CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc index a992217d40..db356f7351 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -46,15 +46,23 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; for (size_t idx = start_index; idx <= end_index; ++idx) { auto cnode = communication_op_info.communication_op_nodes[idx]; + int64_t rank_size = 1; + if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) { + rank_size = AnfAlgo::GetNodeAttr(cnode, kAttrRankSize); + } MS_EXCEPTION_IF_NULL(cnode); for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); - outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); - outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) { + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); + std::vector shape = AnfAlgo::GetOutputInferShape(cnode, output_index); + shape[0] /= rank_size; + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } } builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); @@ -182,18 +190,27 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr auto kernel_info = std::make_shared(); MS_EXCEPTION_IF_NULL(kernel_info); fused_node->set_kernel_info(kernel_info); - AbstractBasePtrList abstract_list; - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - abstract_list.push_back(cnode->abstract()); + auto final_node = communication_op_info.communication_op_nodes[end_index]; + size_t node_num = end_index - start_index + 1; + int64_t rank_size = 1; + if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) { + rank_size = AnfAlgo::GetNodeAttr(final_node, kAttrRankSize); + } + size_t output_num = node_num * rank_size; + std::vector dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0)); + std::vector> shapes; + for (size_t i = 0; i < IntToSize(rank_size); ++i) { + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + std::vector shape = AnfAlgo::GetOutputInferShape(cnode, 0); + shape[0] /= rank_size; + shapes.push_back(shape); + } } + AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get()); auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get()); - auto abstract_tuple = std::make_shared(abstract_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - fused_node->set_abstract(abstract_tuple); - auto final_node = communication_op_info.communication_op_nodes[end_index]; AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node); AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node); AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node);