From 6541b96c40b6dd10c60bc4a612701cfba050278f Mon Sep 17 00:00:00 2001
From: liujunzhu
Date: Fri, 5 Mar 2021 10:03:31 +0800
Subject: [PATCH] Add communication parallel mode.

---
 mindspore/ccsrc/frontend/parallel/context.cc  | 15 +++++++
 mindspore/ccsrc/frontend/parallel/context.h   |  8 ++++
 mindspore/ccsrc/pipeline/jit/init.cc          |  2 +
 .../device/ascend/ascend_stream_assign.cc     | 41 ++++++++++++++++++-
 mindspore/parallel/_auto_parallel_context.py  | 37 +++++++++++++++--
 .../test_set_auto_parallel_context.py         | 10 ++++-
 6 files changed, 108 insertions(+), 5 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index 9e832b09be..f42d94152b 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -33,6 +33,9 @@ std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRI
                                                AUTO_PARALLEL};
 std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
 
+std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
+                                                       NO_GROUP_PARALLEL};
+
 std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
 
 std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
@@ -65,6 +68,7 @@ void ParallelContext::Reset() {
   strategy_search_mode_ = DYNAMIC_PROGRAMMING;
   pipeline_stage_split_num_ = 1;
   grad_accumulation_step_ = 1;
+  communi_parallel_mode_ = ALL_GROUP_PARALLEL;
 }
 
 void ParallelContext::set_device_num(int64_t device_num) {
@@ -152,6 +156,17 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
   return {};
 }
 
+bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
+  auto iter = std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode);
+  if (iter == COMMUNI_PARALLEL_MODE_LIST.end()) {
+    MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
+    return false;
+  }
+
+  communi_parallel_mode_ = communi_parallel_mode;
+  return true;
+}
+
 // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
 void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index 910c371266..e63d45835d 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -46,6 +46,10 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
 constexpr char TRAINING[] = "training";
 constexpr char ACCUMULATION[] = "accumulation";
 
+constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
+constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
+constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
+
 class ParallelContext {
  public:
   ~ParallelContext() = default;
@@ -112,6 +116,9 @@ class ParallelContext {
   }
   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
 
+  bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
+  std::string communi_parallel_mode() const { return communi_parallel_mode_; }
+
   void Reset();
   void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
   void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
@@ -144,6 +151,7 @@ class ParallelContext {
   std::string group_ckpt_save_file_;
   bool enable_parallel_optimizer_;
   bool init_param_shape_;
+  std::string communi_parallel_mode_;
 };
 
 }  // namespace parallel
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index 57349f6502..ac130e960f 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -169,6 +169,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set enable/disable parallel optimizer.")
     .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
          "Get enable/disable parallel optimizer.")
+    .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
+    .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
index 7cfe92d0c7..0b70efb9ff 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
@@ -40,6 +40,7 @@ namespace ascend {
 namespace {
 constexpr uint32_t kDeviceNumOfServer = 8;
 constexpr uint32_t kDeviceNumThreshold = 1024;
+const char kDefaultGroup[] = "__default_group";
 
 constexpr uint32_t kMaxStreamNum = 1024;
 constexpr uint32_t kHcomSecondaryStreamNum = 3;
@@ -60,13 +61,48 @@ bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
   return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
 }
 
+string DoGetHcomGroup(const string &original_group) {
+  string communi_parallel_mode = parallel::ParallelContext::GetInstance()->communi_parallel_mode();
+
+  if (communi_parallel_mode == parallel::ALL_GROUP_PARALLEL) {
+    return original_group;
+  }
+
+  if (communi_parallel_mode == parallel::NO_GROUP_PARALLEL) {
+    return kDefaultGroup;
+  }
+
+  MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
+  auto group_info = parallel::g_device_manager->group_info();
+  for (const auto &info : group_info) {
+    if (info.first != original_group) {
+      continue;
+    }
+
+    const auto &rank_ids = info.second;
+    if (IsSameServer(rank_ids)) {
+      return original_group;
+    } else {
+      return kDefaultGroup;
+    }
+  }
+
+  // world group is not in group_info.
+  return kDefaultGroup;
+}
+
 string GetHcomGroup(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
     MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
   }
-  return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
+
+  auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
+  auto new_group = DoGetHcomGroup(group_name);
+  MS_LOG_INFO << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name
+              << ", new group: " << new_group;
+
+  return new_group;
 }
 
 uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
@@ -167,6 +203,9 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u
 
 void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
   if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
+    MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
+                 << ".";
+
     Reset();
     SetLoopSink();
     ReorderIndependentOrders(graph_ptr);
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index ebc5dcd153..83b0a1731a 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -480,6 +480,26 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_enable_parallel_optimizer()
 
+    def set_communi_parallel_mode(self, communi_parallel_mode):
+        """
+        Set communication parallel mode.
+
+        Args:
+            communi_parallel_mode (str): The communication parallel mode.
+
+        Raises:
+            ValueError: If communication parallel mode is not supported.
+        """
+        self.check_context_handle()
+        ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
+        if ret is False:
+            raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
+
+    def get_communi_parallel_mode(self):
+        """Get communication parallel mode."""
+        self.check_context_handle()
+        return self._context_handle.get_communi_parallel_mode()
+
     def reset(self):
         """Reset all settings."""
         self.check_context_handle()
@@ -518,7 +538,8 @@ _set_auto_parallel_context_func_map = {
     "full_batch": auto_parallel_context().set_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
-    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
+    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
+    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
 
 
 _get_auto_parallel_context_func_map = {
@@ -536,14 +557,16 @@ _get_auto_parallel_context_func_map = {
     "full_batch": auto_parallel_context().get_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
-    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
+    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
+    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
+                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
+                 communi_parallel_mode=str)
 def _set_auto_parallel_context(**kwargs):
     """
@@ -592,6 +615,14 @@ def _set_auto_parallel_context(**kwargs):
             the devices are distributed alone the pipeline. The total devices will be divided into
             'pipeline_stags' stages. This currently could only be used when parall mode
             semi_auto_parallel is enabled. Default: 0
+        communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
+            "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
+
+            - all_group_parallel: All communication groups are in parallel.
+
+            - same_server_group_parallel: Only the communication groups within the same server are parallel.
+
+            - no_group_parallel: No communication groups are in parallel.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py
index 5f879064b7..5f6401718f 100644
--- a/tests/ut/python/parallel/test_set_auto_parallel_context.py
+++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py
@@ -21,19 +21,22 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 
 def test_set_auto_parallel_context():
     context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False,
-                                      parallel_mode="auto_parallel", parameter_broadcast=False)
+                                      parallel_mode="auto_parallel", parameter_broadcast=False,
+                                      communi_parallel_mode="same_server_group_parallel")
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     gradients_mean = context.get_auto_parallel_context("gradients_mean")
     gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
+    communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
     assert device_num == 4
     assert global_rank == 3
     assert gradients_mean
     assert not gradient_fp32_sync
     assert parallel_mode == "auto_parallel"
     assert not parameter_broadcast
+    assert communi_parallel_mode == "same_server_group_parallel"
 
     auto_parallel_context().set_device_num(4)
     device_num = auto_parallel_context().get_device_num()
@@ -77,6 +80,9 @@ def test_set_auto_parallel_context():
     with pytest.raises(ValueError):
         set_algo_parameters(tensor_slice_align_size=1025)
 
+    with pytest.raises(ValueError):
+        context.set_auto_parallel_context(communi_parallel_mode="wrong_mode")
+
     context.set_auto_parallel_context(enable_parallel_optimizer=True)
     assert context.get_auto_parallel_context("enable_parallel_optimizer")
     assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
@@ -98,6 +104,7 @@ def test_reset_auto_parallel_context():
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     stage = auto_parallel_context().get_pipeline_stages()
+    communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
 
     assert device_num == 1
     assert global_rank == 0
@@ -108,3 +115,4 @@ def test_reset_auto_parallel_context():
     assert not device_num_is_set
     assert not parameter_broadcast_is_set
     assert stage == 1
+    assert communi_parallel_mode == "all_group_parallel"
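
Usage sketch (editor's note, not part of the diff above): with this patch applied, the new mode is driven
entirely through the Python auto-parallel context, exactly as the unit test exercises it. The snippet below
is a minimal sketch assuming a MindSpore build that contains this patch and an Ascend/HCCL setup where
communication groups actually exist; it only uses the set/get/reset context APIs shown in the patch.

    from mindspore import context

    # Let only communication groups whose ranks sit on the same server run in parallel;
    # per DoGetHcomGroup() above, cross-server groups are mapped onto the shared default group.
    context.set_auto_parallel_context(communi_parallel_mode="same_server_group_parallel")
    assert context.get_auto_parallel_context("communi_parallel_mode") == "same_server_group_parallel"

    # Unsupported names are rejected by _AutoParallelContext.set_communi_parallel_mode():
    # context.set_auto_parallel_context(communi_parallel_mode="wrong_mode")  # raises ValueError

    # Resetting the auto-parallel context restores the default mode.
    context.reset_auto_parallel_context()
    assert context.get_auto_parallel_context("communi_parallel_mode") == "all_group_parallel"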