diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 732441461b..3a2084a6e2 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -63,6 +63,8 @@ void ParallelContext::Reset() { all_reduce_fusion_split_indices_.clear(); all_reduce_fusion_split_sizes_.clear(); strategy_search_mode_ = DYNAMIC_PROGRAMMING; + stages_.clear(); + pipeline_stage_split_num_ = 0; } void ParallelContext::set_device_num(int32_t device_num) { @@ -83,6 +85,10 @@ void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } +void ParallelContext::set_pipeline_stage_split_num(const int32_t stage_num) { pipeline_stage_split_num_ = stage_num; } + +void ParallelContext::set_stage(const std::vector &stages) { stages_ = stages; } + bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); if (iter == PARALLEL_MODE_LIST.end()) { diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index 3f55f9a152..f1981aac33 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -67,6 +67,12 @@ class ParallelContext { void set_device_num(int32_t device_num); int32_t device_num() const { return device_num_; } + void set_pipeline_stage_split_num(const int32_t stages); + int32_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; } + + void set_stage(const std::vector &stages); + std::vector stage() const { return stages_; } + void set_global_rank(int32_t global_rank); int32_t global_rank() const { return global_rank_; } @@ -115,6 +121,8 @@ class ParallelContext { int32_t global_rank_; std::string parallel_mode_; std::string strategy_search_mode_; + std::vector stages_; + int32_t pipeline_stage_split_num_; bool parameter_broadcast_; bool device_num_is_set_; bool global_rank_is_set_; diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.cc b/mindspore/ccsrc/frontend/parallel/device_manager.cc index a272702cd1..5202e038c8 100644 --- a/mindspore/ccsrc/frontend/parallel/device_manager.cc +++ b/mindspore/ccsrc/frontend/parallel/device_manager.cc @@ -36,7 +36,8 @@ Stage::Stage(const std::vector &devices, int num, i // NOTE: '-1' indicates ERROR int Stage::global_rank(Group *g) const { return ((g == nullptr) ? 
rank_ : -1); } -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend, + const std::vector<int32_t> &stage) { if (device_num <= 0) { MS_LOG(ERROR) << "'device_num' must be positive."; return false; @@ -68,7 +69,30 @@ bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &back devices.push_back(i); } - stage_map.push_back(device_num); + if (stage.size()) { + int32_t summed_value = 0; + for (auto begin = stage.begin(); begin != stage.end(); ++begin) { + if (*begin <= 0) { + MS_LOG(ERROR) << "The value in the pipeline stages should be a positive value"; + return false; + } + summed_value += *begin; + stage_map.push_back(*begin); + } + + if (summed_value != device_num) { + MS_LOG(ERROR) << "The sum of the pipeline stages: " << summed_value << " is not equal to the device_num " + << device_num; + return false; + } + } else { + stage_map.push_back(device_num); + } + + for (auto &y : stage_map) { + MS_LOG(DEBUG) << "Obtained stage id: " << y; + } + g_device_manager = std::make_shared<DeviceManager>(); if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) { MS_LOG(INFO) << "Device initialization succeeds."; diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h index 3023f4a355..79e0487c10 100644 --- a/mindspore/ccsrc/frontend/parallel/device_manager.h +++ b/mindspore/ccsrc/frontend/parallel/device_manager.h @@ -70,7 +70,7 @@ class Stage { // This method is used for initializing the global DeviceManager 'g_device_manager', // arguments including 'device_num' and 'global_rank' -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend, const std::vector<int32_t> &stage); void CheckGlobalDeviceManager(); diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.cc b/mindspore/ccsrc/frontend/parallel/device_matrix.cc index e54f6d84ee..8533b86670 100644 --- a/mindspore/ccsrc/frontend/parallel/device_matrix.cc +++ b/mindspore/ccsrc/frontend/parallel/device_matrix.cc @@ -126,9 +126,22 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra } } - Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); + // Convert the global rank to the local rank (the index in the device list) to compute the coordinate + uint32_t local_rank = 0; for (auto &tmp_rank : dev_list_) { - Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_); + if (tmp_rank == rank_) { + break; + } + ++local_rank; + } + if (local_rank == dev_list_.size()) { + MS_LOG(ERROR) << "Rank id: " << rank_ << " is not in the device list."; + return FAILED; + } + + Shape current_rank_coordinate = ConvertRankToCoordinate((int32_t)local_rank, dev_shape_); + for (uint32_t loop_local_rank = 0; loop_local_rank < dev_list_.size(); ++loop_local_rank) { + Shape tmp_rank_coordinate = ConvertRankToCoordinate(loop_local_rank, dev_shape_); bool matched = true; for (auto &map : tensor_map) { if (map == MAP_NONE) { @@ -141,7 +154,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra } } if (matched) { - rank_list->push_back(tmp_rank); + rank_list->push_back(dev_list_[loop_local_rank]); } } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc index b6bd16fae0..e4da62f666 100644 ---
a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc @@ -43,7 +43,7 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { // dropout don't support repeated calculation CheckGlobalDeviceManager(); auto input_strategy = strategy->GetInputDim().at(0); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies()); if (IntToSize(product_p) != dev_num) { MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index 64607cd7b8..313616fb30 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -196,7 +196,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) { // Don't support repeated calc CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); if (IntToSize(product_p) < dev_num) { MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc"; @@ -269,7 +269,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { // param_strategy(axis) != 1, Don't support repeated calc CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { MS_LOG(DEBUG) << name_ << ": Invalid strategy. 
Don't support repeated calc."; @@ -346,7 +346,7 @@ Status GatherV2PInfo::InferDevMatrixShape() { out_dev_matrix_shape_ = dev_matrix_shape_; } CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); if (param_product * index_product < SizeToInt(dev_num)) { @@ -516,10 +516,11 @@ Status GatherV2PInfo::InferGroup() { if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { dim = (axis_ + 1) % 2; } + CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); + RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_); int32_t rank = g_device_manager->global_rank(); - RankList dev_list = g_device_manager->GetDeviceListByStageId(0); DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_); RankList group_devices; if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index ce5dc59131..ce902be2d7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -162,7 +162,8 @@ class OperatorInfo { void set_type(const std::string &type) { type_ = type; } const std::string &type() const { return type_; } const std::unordered_map &attrs() const { return attrs_; } - + void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; } + int32_t stage_id() const { return stage_id_; } // Key for user data. 
constexpr static char key[] = "OpInfo"; @@ -205,6 +206,7 @@ class OperatorInfo { std::vector input_value_; TypePtr outputs_dtype_; + int32_t stage_id_ = 0; StrategyPtr strategy_; std::vector inputs_tensor_info_; std::vector outputs_tensor_info_; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 6729dceb0a..53f36278a2 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -55,6 +55,7 @@ constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only"; constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only"; constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only"; constexpr char STRATEGY[] = "strategy"; +constexpr char STAGE_ATTR[] = "stage"; constexpr char GEN_STRATEGY[] = "gen_strategy"; constexpr char REDUCE_OP_SUM[] = "sum"; constexpr char REDUCE_OP_MAX[] = "max"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc index d8982b2176..6f6e75b53f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -133,9 +133,9 @@ Status ReduceMethod::InferTensorMap() { return SUCCESS; } -bool IsDataParallelStrategy(const Dimensions &strategy) { +bool IsDataParallelStrategy(const Dimensions &strategy, int32_t stage_id) { CheckGlobalDeviceManager(); - size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); if (strategy.empty()) { MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty"; } @@ -145,7 +145,7 @@ bool IsDataParallelStrategy(const Dimensions &strategy) { Status ReduceMethod::InferForwardCommunication() { Dimensions stra = strategy_->GetInputDim().at(0); - if (cross_batch_ && IsDataParallelStrategy(stra)) { + if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) { MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; return SUCCESS; } @@ -211,7 +211,7 @@ ForwardOp CreatReduceMeanForwardOp(const std::vector &forward_group, cons Status ReduceMeanInfo::InferForwardCommunication() { Dimensions stra = strategy_->GetInputDim().at(0); - if (cross_batch_ && IsDataParallelStrategy(stra)) { + if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) { MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 1189e72165..d64fcac7f5 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -998,6 +998,17 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt StrategyPtr ExtractStrategy(std::unordered_map attrs) { ValueTuplePtr var = attrs[STRATEGY]->cast(); StrategyPtr strategyPtr; + std::vector stages = ParallelContext::GetInstance()->stage(); + auto res = attrs.find(STAGE_ATTR); + int32_t stage_id = 0; + if (res != attrs.end()) { + stage_id = GetValue(res->second); + } + if (stage_id && stages.empty()) { + MS_LOG(ERROR) << "Find stage id:" << stage_id << " but the pipeline_stages is 0."; + return nullptr; + } + MS_LOG(INFO) << "Extract information: strategy " << 
attrs[STRATEGY]->ToString(); if (var == nullptr) { MS_LOG(EXCEPTION) << "Strategy value is nullptr"; @@ -1016,13 +1027,13 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { }); strategy.push_back(dim); } else { - MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; + MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence"; } } if (strategy.empty()) { MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy"; } - strategyPtr = NewStrategy(0, strategy); + strategyPtr = NewStrategy(stage_id, strategy); } return strategyPtr; @@ -1420,6 +1431,30 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { (void)prim->SetAttrs(attrs_temp); } } +// This function aims to check the valid rank and stage in the operations +// If the rank is not valid for the given stage, we chose not to init the strategy of the operation +// For example stage is [4, 4], and the group_list [[0,1,2,3],[4,5,6,7]] +// For stage 0, we require the rank_id is in [0,1,2,3] +Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) { + RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage); + int32_t target = global_rank; + if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) { + return Status::SUCCESS; + } + + return Status::FAILED; +} + +Status ValidStageCheck(const std::vector &stages, int32_t strategy_stage) { + if (stages.size() > 0) { + if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) { + return Status::SUCCESS; + } + return Status::FAILED; + } else { + return Status::SUCCESS; + } +} void ExtractInformation(const std::vector &all_nodes) { // load strategy map from checkpoint @@ -1429,6 +1464,11 @@ void ExtractInformation(const std::vector &all_nodes) { MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; } } + + // Get global rank after the checkpoint? 
+ int32_t global_rank = ParallelContext::GetInstance()->global_rank(); + std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); + for (auto &node : all_nodes) { auto cnode = node->cast<CNodePtr>(); if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { @@ -1501,7 +1541,18 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { strategyPtr = ExtractStrategy(attrs); } if (strategyPtr != nullptr) { - if (operator_->Init(strategyPtr) == FAILED) { + (*operator_).set_stage_id(strategyPtr->GetInputStage()); + MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id(); + if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) { + MS_LOG(ERROR) << "The stage " << strategyPtr->GetInputStage() << " for operator " << prim->name() + << " exceeds the global stage size " << stages.size() << '.'; + return; + } + // If the strategy is not valid for the given global rank, then we skip the Init of the strategy + if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) { + MS_LOG(INFO) << "The global rank exceeds the range of the stage, skip the strategy init for operator " + << prim->name(); + } else if (operator_->Init(strategyPtr) == FAILED) { MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; } cnode->set_user_data<OperatorInfo>(operator_); @@ -2416,6 +2467,9 @@ Status ParallelInit() { MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); int32_t device_num = ParallelContext::GetInstance()->device_num(); int32_t global_rank = ParallelContext::GetInstance()->global_rank(); + int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); + std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); + std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); @@ -2431,6 +2485,26 @@ Status ParallelInit() { MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; } + if (device_num <= 0) { + MS_LOG(ERROR) << "Invalid device num " << device_num << ", expected a positive device number"; + return FAILED; + } + if (split_stage_num > 0) { + if (device_num % split_stage_num != 0) { + MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num + << ", as we support only exact division now"; + return FAILED; + } + for (int i = 0; i < split_stage_num; i++) { + stages.push_back(device_num / split_stage_num); + } + } else if (split_stage_num < 0) { + MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a non-negative stage number"; + return FAILED; + } + + ParallelContext::GetInstance()->set_stage(stages); + uint32_t world_rank_size = 0; if (!ParallelContext::GetInstance()->device_num_is_set()) { if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) { @@ -2449,7 +2523,12 @@ Status ParallelInit() { MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; } - if (!InitDevice(device_num, global_rank, communication_backend)) { + if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) { + MS_LOG(ERROR) << "To enable pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL; + return FAILED; + } + + if (!InitDevice(device_num, global_rank, communication_backend, stages)) { MS_LOG(ERROR) << "Init device failed"; return FAILED; } @@ -2457,6 +2536,7 @@ MS_LOG(INFO) << "The parallel context: dev num: "
<< device_num << ", global rank: " << global_rank << ", backend: " << backend << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean() << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync(); + return SUCCESS; } diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 697039243f..b5cbabe53d 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -152,6 +152,9 @@ PYBIND11_MODULE(_c_expression, m) { "Set strategy checkpoint save file.") .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") + .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num, + "Set pipeline stage split num.") + .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.") .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, diff --git a/mindspore/context.py b/mindspore/context.py index aa4ee135a4..cd2ed2c1d8 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -331,7 +331,7 @@ def _context(): @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, - all_reduce_fusion_config=list) + all_reduce_fusion_config=list, pipeline_stages=int) def set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -357,6 +357,7 @@ def set_auto_parallel_context(**kwargs): parallel_mode strategy_ckpt_load_file all_reduce_fusion_config strategy_ckpt_save_file full_batch + pipeline_stages =========================== =========================== ================= Args: @@ -399,6 +400,10 @@ def set_auto_parallel_context(**kwargs): the fusion is closed. all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed. + pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how + the devices are distributed along the pipeline. The total devices will be divided into + 'pipeline_stages' stages. Currently this could only be used when + the parallel mode semi_auto_parallel is enabled. Raises: ValueError: If input key is not attribute in auto parallel context. @@ -416,10 +421,10 @@ def set_auto_parallel_context(**kwargs): >>> context.set_auto_parallel_context(full_batch=True) >>> context.set_auto_parallel_context(enable_parallel_optimizer=False) >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160]) + >>> context.set_auto_parallel_context(pipeline_stages=2) """ _set_auto_parallel_context(**kwargs) - def get_auto_parallel_context(attr_key): """ Gets auto parallel context attribute value according to the key.
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 3b217719e2..2a91024b0d 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -102,6 +102,20 @@ class Primitive(Primitive_): self.add_attr(name, value) return self + def set_stage(self, stage): + """ + Add stage id to primitive attribute. + + Note: + It is valid only in semi auto parallel. + In other parallel modes, please set it to be 0. + + Args: + stage (int): The stage id for the current operation + """ + self.add_prim_attr("stage", stage) + return self + def shard(self, strategy): """ Add strategies to primitive attribute. diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 91e3189d2e..01afb012fa 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -95,6 +95,16 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_global_rank() + def set_pipeline_stages(self, stages): + """Set the stages of the pipeline""" + self.check_context_handle() + self._context_handle.set_pipeline_stage_split_num(stages) + + def get_pipeline_stages(self): + """Get the stages of the pipeline""" + self.check_context_handle() + return self._context_handle.get_pipeline_stage_split_num() + def set_gradients_mean(self, gradients_mean): """ Set gradients_mean flag. @@ -466,6 +476,7 @@ _set_auto_parallel_context_func_map = { "gradients_mean": auto_parallel_context().set_gradients_mean, "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync, "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, + "pipeline_stages": auto_parallel_context().set_pipeline_stages, "parallel_mode": auto_parallel_context().set_parallel_mode, "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, @@ -482,6 +493,7 @@ _get_auto_parallel_context_func_map = { "gradients_mean": auto_parallel_context().get_gradients_mean, "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync, "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, + "pipeline_stages": auto_parallel_context().get_pipeline_stages, "parallel_mode": auto_parallel_context().get_parallel_mode, "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, @@ -569,7 +581,6 @@ def _get_auto_parallel_context(attr_key): get_func = _get_auto_parallel_context_func_map[attr_key] return get_func() - def _reset_auto_parallel_context(): """ Reset auto parallel context attributes to the default values: @@ -584,5 +595,6 @@ def _reset_auto_parallel_context(): - strategy_ckpt_save_file: "" - enable_parallel_optimizer: False - auto_parallel_search_mode: dynamic_programming + - pipeline_stages: 0 """ auto_parallel_context().reset() diff --git a/tests/ut/cpp/parallel/device_matrix_test.cc b/tests/ut/cpp/parallel/device_matrix_test.cc index 57a438e76e..2059f1b78e 100644 --- a/tests/ut/cpp/parallel/device_matrix_test.cc +++ b/tests/ut/cpp/parallel/device_matrix_test.cc @@ -83,6 +83,39 @@ TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) { EXPECT_THROW({ DeviceMatrix arr(3, dev_list, shape); }, std::runtime_error); } +TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceOne) { + RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0}; + Shape tensor_map = {-1, 0}; + RankList rank_list; + Shape shape = {4, 2}; + 
DeviceMatrix arr(0, dev_list, shape); + arr.GetDevicesByTensorMap(tensor_map, &rank_list); + RankList rank_list_expect = {3, 9, 100, 0}; + ASSERT_EQ(rank_list, rank_list_expect); +} + +TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceTwo) { + RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0}; + Shape tensor_map = {1, 0}; + RankList rank_list; + Shape shape = {4, 2}; + DeviceMatrix arr(0, dev_list, shape); + arr.GetDevicesByTensorMap(tensor_map, &rank_list); + RankList rank_list_expect = {0}; + ASSERT_EQ(rank_list, rank_list_expect); +} + +TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapNormalOrder2D) { + RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7}; + Shape tensor_map = {-1, 0}; + RankList rank_list; + Shape shape = {4, 2}; + DeviceMatrix arr(6, dev_list, shape); + arr.GetDevicesByTensorMap(tensor_map, &rank_list); + RankList rank_list_expect = {0, 2, 4, 6}; + ASSERT_EQ(rank_list, rank_list_expect); +} + TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) { // Rank is out of range RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7}; diff --git a/tests/ut/python/parallel/test_pipeline_parallel.py b/tests/ut/python/parallel/test_pipeline_parallel.py new file mode 100644 index 0000000000..d20c77c98a --- /dev/null +++ b/tests/ut/python/parallel/test_pipeline_parallel.py @@ -0,0 +1,89 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return grad_all(self.network)(x, y) + + +class Net(nn.Cell): + def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""): + super().__init__() + if shape is None: + shape = [64, 64] + self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target) + self.mul = P.Mul().shard(strategy2) + self.index = Tensor(np.ones(shape), dtype=ms.int32) + self.gatherv2.set_stage(stage1) + self.mul.set_stage(stage2) + self.axis = axis + + def construct(self, x, y): + out = self.gatherv2(x, self.index, self.axis) + out = self.mul(out, y) + return out + + +def test_gatherv2_semi_samestage1(): + context.set_auto_parallel_context(device_num=8, global_rank=0, \ + parallel_mode="semi_auto_parallel", pipeline_stages=2) + strategy1 = ((1, 2), (1, 1)) + strategy2 = ((2, 1, 1), (2, 1, 1)) + net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_samestage2(): + context.set_auto_parallel_context(device_num=8, global_rank=5, \ + parallel_mode="semi_auto_parallel", pipeline_stages=2) + strategy1 = ((1, 2), (1, 1)) + strategy2 = ((2, 1, 1), (2, 1, 1)) + net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py index 5695784740..534a333940 100644 --- a/tests/ut/python/parallel/test_set_auto_parallel_context.py +++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py @@ -81,6 +81,11 @@ def test_set_auto_parallel_context(): assert context.get_auto_parallel_context("enable_parallel_optimizer") assert not auto_parallel_context().get_all_reduce_fusion_split_indices() +def test_pipeline_parallel_context(): + context.set_auto_parallel_context(device_num=8, global_rank=4, + parallel_mode="semi_auto_parallel", pipeline_stages=2) + stage = auto_parallel_context().get_pipeline_stages() + assert stage == 2 def test_reset_auto_parallel_context(): context.reset_auto_parallel_context() @@ -92,6 +97,8 @@ def test_reset_auto_parallel_context(): parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") device_num_is_set = auto_parallel_context().get_device_num_is_set() parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set() + stage = auto_parallel_context().get_pipeline_stages() + assert 
device_num == 1 assert global_rank == 0 assert not gradients_mean @@ -100,3 +107,4 @@ def test_reset_auto_parallel_context(): assert not parameter_broadcast assert not device_num_is_set assert not parameter_broadcast_is_set + assert not stage
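
Usage sketch (not part of the patch): the changes above expose two user-facing hooks, context.set_auto_parallel_context(pipeline_stages=N) and Primitive.set_stage(stage). A minimal illustration distilled from the new test_pipeline_parallel.py; the TwoStageNet, its MatMul operators, shard strategies, and tensor shapes are illustrative assumptions only.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

# Eight devices split into two pipeline stages of four devices each.
# Pipeline stages currently require the semi_auto_parallel mode.
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel",
                                  pipeline_stages=2)

class TwoStageNet(nn.Cell):
    # Hypothetical network: each operator is tagged with the pipeline stage
    # that should execute it, and its shard strategy multiplies out to the
    # four devices owned by that stage (2 * 2 = 4).
    def __init__(self):
        super().__init__()
        self.matmul1 = P.MatMul().shard(((2, 1), (1, 2)))
        self.matmul2 = P.MatMul().shard(((2, 1), (1, 2)))
        self.matmul1.set_stage(0)  # computed by ranks 0-3
        self.matmul2.set_stage(1)  # computed by ranks 4-7

    def construct(self, x, w1, w2):
        return self.matmul2(self.matmul1(x, w1), w2)

net = TwoStageNet()
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
w1 = Tensor(np.ones([64, 64]), dtype=ms.float32)
w2 = Tensor(np.ones([64, 64]), dtype=ms.float32)

Compiling such a net (as the new unit tests do via _executor.compile) exercises the per-stage device grouping added in InitDevice as well as the stage and rank validity checks added in ExtractInformation.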