From 278e82a84983fe4551abc7fd7e0fa9535308e4b7 Mon Sep 17 00:00:00 2001
From: yangzhenzhang <285824651@qq.com>
Date: Tue, 17 Nov 2020 20:14:42 +0800
Subject: [PATCH] update pipeline parallel

---
 mindspore/ccsrc/frontend/parallel/context.cc | 5 +-
 mindspore/ccsrc/frontend/parallel/context.h | 6 +-
 .../ccsrc/frontend/parallel/device_manager.cc | 42 ++++----
 .../ccsrc/frontend/parallel/step_parallel.cc | 102 ++++++------------
 mindspore/context.py | 3 +-
 .../test_set_auto_parallel_context.py | 2 +-
 6 files changed, 57 insertions(+), 103 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index a9c2b3efff..7d82feb6ce 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -63,8 +63,7 @@ void ParallelContext::Reset() {
   all_reduce_fusion_split_indices_.clear();
   all_reduce_fusion_split_sizes_.clear();
   strategy_search_mode_ = DYNAMIC_PROGRAMMING;
-  stages_.clear();
-  pipeline_stage_split_num_ = 0;
+  pipeline_stage_split_num_ = 1;
 }
 
 void ParallelContext::set_device_num(int64_t device_num) {
@@ -87,8 +86,6 @@ void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_rep
 
 void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
 
-void ParallelContext::set_stage(const std::vector<int64_t> &stages) { stages_ = stages; }
-
 bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
   auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
   if (iter == PARALLEL_MODE_LIST.end()) {
diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index ff6a970a54..b58b9b2371 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -70,9 +70,6 @@ class ParallelContext {
   void set_pipeline_stage_split_num(const int64_t stages);
   int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
 
-  void set_stage(const std::vector<int64_t> &stages);
-  std::vector<int64_t> stage() const { return stages_; }
-
   void set_global_rank(int64_t global_rank);
   int64_t global_rank() const { return global_rank_; }
 
@@ -121,8 +118,7 @@ class ParallelContext {
   int64_t global_rank_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
-  std::vector<int64_t> stages_;
-  int64_t pipeline_stage_split_num_ = 0;
+  int64_t pipeline_stage_split_num_;
   bool parameter_broadcast_;
   bool device_num_is_set_;
   bool global_rank_is_set_;
diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.cc b/mindspore/ccsrc/frontend/parallel/device_manager.cc
index 5ec60d8f79..52010a4f41 100644
--- a/mindspore/ccsrc/frontend/parallel/device_manager.cc
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.cc
@@ -54,44 +54,44 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
     MS_LOG(ERROR) << "Invalid backend: " << backend;
     return false;
   }
+  if (stage.empty()) {
+    MS_LOG(ERROR) << "The size of stage must be positive";
+    return false;
+  }
 
   RankList devices, stage_map;
   for (int64_t i = 0; i < device_num; ++i) {
     devices.push_back(i);
   }
 
-  if (stage.size()) {
-    int64_t summed_value = 0;
-    for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
-      if (*begin <= 0) {
-        MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
-        return false;
-      }
-      summed_value += *begin;
-      stage_map.push_back(*begin);
-    }
-
-    if (summed_value != device_num) {
-      MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
-                    << device_num;
+  int64_t summed_value = 0;
+  for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
+    if (*begin <= 0) {
+      MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
       return false;
     }
-  } else {
-    stage_map.push_back(device_num);
+    summed_value += *begin;
+    stage_map.push_back(*begin);
   }
 
-  for (auto &y : stage_map) {
-    MS_LOG(DEBUG) << "Obtained stage id :" << y;
+  if (summed_value != device_num) {
+    MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
+                  << device_num;
+    return false;
+  }
+
+  for (auto &ele : stage_map) {
+    MS_LOG(DEBUG) << "Obtained stage id: " << ele;
   }
 
   g_device_manager = std::make_shared<DeviceManager>();
   if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
     MS_LOG(INFO) << "Device initialization succeeds.";
     return true;
-  } else {
-    MS_LOG(ERROR) << "Device initialization fails.";
-    return false;
   }
+
+  MS_LOG(ERROR) << "Device initialization fails.";
+  return false;
 }
 
 void CheckGlobalDeviceManager() {
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index a3a76b40b1..c66b3bb91e 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1125,16 +1125,7 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt
 StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
   ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
   StrategyPtr strategyPtr;
-  std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
-  auto res = attrs.find(STAGE_ATTR);
-  int64_t stage_id = 0;
-  if (res != attrs.end()) {
-    stage_id = GetValue<int64_t>(res->second);
-  }
-  if (stage_id && stages.empty()) {
-    MS_LOG(ERROR) << "Find stage id:" << stage_id << " but the pipeline_stages is 0.";
-    return nullptr;
-  }
+  int64_t stage_id = g_device_manager->stage_id();
 
   MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
   if (var == nullptr) {
@@ -1152,11 +1143,11 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
                              [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
         strategy.push_back(dim);
       } else {
-        MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence";
+        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
       }
     }
     if (strategy.empty()) {
-      MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy";
+      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
     }
     strategyPtr = NewStrategy(stage_id, strategy);
   }
@@ -1663,30 +1654,6 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
     (void)prim->SetAttrs(attrs_temp);
   }
 }
-// This function aims to check the valid rank and stage in the operations
-// If the rank is not valid for the given stage, we chose not to init the strategy of the operation
-// For example stage is [4, 4], and the group_list [[0,1,2,3],[4,5,6,7]]
-// For stage 0, we require the rank_id is in [0,1,2,3]
-Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) {
-  RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage);
-  int32_t target = global_rank;
-  if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) {
-    return Status::SUCCESS;
-  }
-
-  return Status::FAILED;
-}
-
-Status ValidStageCheck(const std::vector<int64_t> &stages, int32_t strategy_stage) {
-  if (stages.size() > 0) {
-    if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) {
-      return Status::SUCCESS;
-    }
-    return Status::FAILED;
-  } else {
-    return Status::SUCCESS;
-  }
-}
 
 // find previous parallel care node.
 bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
@@ -1781,9 +1748,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
     FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
     MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
   }
-  // Get global rank after the checkpoint?
-  int64_t global_rank = ParallelContext::GetInstance()->global_rank();
-  std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
+
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
@@ -1848,18 +1813,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
       if (is_last_nodes && full_batch) {
        SetLastNodeStrategy(strategyPtr);
      }
-      (*operator_).set_stage_id(strategyPtr->GetInputStage());
-      MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
-      if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
-        MS_LOG(ERROR) << "Find stage " << strategyPtr->GetInputStage() << " for operator " << prim->name()
-                      << " exceeds the global stage size " << stages.size() << '.';
-        return;
-      }
-      // If the strategy is not valid for the given global rank, then we skip the Init of the strategy
-      if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) {
-        MS_LOG(INFO) << "Find global exceeds the range of the stage, skip the strategy init for operator "
-                     << prim->name();
-      } else if (operator_->Init(strategyPtr) == FAILED) {
+      if (operator_->Init(strategyPtr) == FAILED) {
         MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
       }
       cnode->set_user_data<OperatorInfo>(operator_);
@@ -2800,7 +2754,6 @@ Status ParallelInit() {
   int64_t device_num = ParallelContext::GetInstance()->device_num();
   int64_t global_rank = ParallelContext::GetInstance()->global_rank();
   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
-  std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
@@ -2814,29 +2767,15 @@ Status ParallelInit() {
     world_group = NCCL_WORLD_GROUP;
     communication_backend = NCCL_BACKEND;
   } else {
-    MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
-  }
-
-  if (device_num <= 0) {
-    MS_LOG(ERROR) << "Invalid device num " << device_num << " , expected a positive device number";
+    MS_LOG(ERROR) << "Invalid communication backend: " << backend;
     return FAILED;
   }
-  if (split_stage_num > 0) {
-    if (device_num % split_stage_num != 0) {
-      MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num
-                    << " , as we support only extract devision now";
-      return FAILED;
-    }
-    for (int i = 0; i < split_stage_num; i++) {
-      stages.push_back(device_num / split_stage_num);
-    }
-  } else if (split_stage_num < 0) {
-    MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << " , expected a positive stage number";
+
+  if (split_stage_num <= 0) {
+    MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
     return FAILED;
   }
 
-  ParallelContext::GetInstance()->set_stage(stages);
-
   uint32_t world_rank_size = 0;
   if (!ParallelContext::GetInstance()->device_num_is_set()) {
     if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
@@ -2855,7 +2794,28 @@ Status ParallelInit() {
     MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
   }
 
-  if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) {
+  if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
+    MS_LOG(ERROR) << "Invalid device num " << device_num;
+    return FAILED;
+  }
+
+  // the device_num maybe get from communication interface
+  if (device_num % split_stage_num != 0) {
+    MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num;
+    return FAILED;
+  }
+
+  if ((global_rank < 0) || (global_rank >= device_num)) {
+    MS_LOG(ERROR) << "Global rank " << global_rank << " is out of range, the device num is " << device_num;
+    return FAILED;
+  }
+
+  std::vector<int64_t> stages;
+  for (int i = 0; i < split_stage_num; i++) {
+    stages.push_back(device_num / split_stage_num);
+  }
+
+  if ((split_stage_num > 1) && (parallel_mode != SEMI_AUTO_PARALLEL)) {
     MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
     return FAILED;
   }
diff --git a/mindspore/context.py b/mindspore/context.py
index 44daf6d464..ae677954b4 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -393,7 +393,7 @@ def set_auto_parallel_context(**kwargs):
         pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
                         the devices are distributed alone the pipeline. The total devices will be divided into
                         'pipeline_stags' stages. This currently could only be used when
-                        parall mode semi_auto_parallel is enabled.
+                        parallel mode semi_auto_parallel is enabled. Default: 1.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -446,6 +446,7 @@ def reset_auto_parallel_context():
     - strategy_ckpt_save_file: ''.
     - full_batch: False.
     - enable_parallel_optimizer: False.
+    - pipeline_stages: 1.
""" _reset_auto_parallel_context() diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py index 534a333940..5f879064b7 100644 --- a/tests/ut/python/parallel/test_set_auto_parallel_context.py +++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py @@ -107,4 +107,4 @@ def test_reset_auto_parallel_context(): assert not parameter_broadcast assert not device_num_is_set assert not parameter_broadcast_is_set - assert not stage + assert stage == 1