diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h
index 1fdcb2f5aa..21d776a067 100644
--- a/mindspore/ccsrc/frontend/parallel/device_manager.h
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.h
@@ -75,6 +75,7 @@ class DeviceManager {
   size_t DeviceNum() const { return devices_.size(); }
 
   int64_t stage_num() const { return stage_num_; }
+  int64_t stage_device_num() const { return stage_device_num_; }
   int64_t stage_id() const { return stage_id_; }
   int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
index 9eb4a4ea90..b457c6d3d7 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
@@ -41,11 +41,9 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
 
   // dropout don't support repeated calculation
-  CheckGlobalDeviceManager();
   auto input_strategy = strategy->GetInputDim().at(0);
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num) {
+  if (product_p != stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
index d0d88f65df..de4617c83c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
@@ -32,11 +32,6 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  int64_t stage = strategy->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
-
   size_t strategy_size = strategy->GetInputNumber();
   Strategys stra = strategy->GetInputDim();
   for (size_t i = 0; i < strategy_size; ++i) {
@@ -46,7 +41,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     for (size_t j = 0; j < strategy_len; ++j) {
       int64_t strategy_value = sub_strategy.at(j);
       if (strategy_value > 1) {
-        if (flag || strategy_value != dev_num_) {
+        if (flag || strategy_value != stage_device_size_) {
           MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
           return FAILED;
         }
@@ -58,7 +53,7 @@
 }
 
 Status BatchParallelInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
 
@@ -81,14 +76,14 @@ Status BatchParallelInfo::InferMirrorOps() {
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status BatchParallelInfo::InferTensorMap() {
-  if (strategy_->GetInputDim()[0][0] != dev_num_) {
+  if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
     MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
     return FAILED;
   }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) {
+      if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
        tensor_map_index.push_back(0);
      } else {
        tensor_map_index.push_back(MAP_NONE);
@@ -117,7 +112,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
     Dimensions strategy;
     for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
       if (i == 0 && j == 0) {
-        strategy.push_back(dev_num_);
+        strategy.push_back(stage_device_size_);
       } else {
         strategy.push_back(1);
       }
@@ -176,14 +171,12 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }
 
 Status BatchParallelInfo::GenerateStrategies(int64_t stage_id) {
-  CheckGlobalDeviceManager();
-  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
   StrategyPtr sp;
   Strategys strategy;
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape temp(inputs_shape_[i].size(), 1);
     if (split_flag_list_[i]) {
-      temp[0] = SizeToLong(total_dev_num);
+      temp[0] = stage_device_size_;
     }
     strategy.push_back(temp);
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
index 04a9f241a8..6d0d84178a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
@@ -151,10 +151,8 @@ Status DropoutDoMaskInfo::GenerateStrategies(int64_t stage_id) {
 }
 
 std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions strategy(inputs_shape_[0].size() - 1, 1);
-  (void)strategy.insert(strategy.begin(), SizeToLong(dev_num));
+  (void)strategy.insert(strategy.begin(), stage_device_size_);
   Strategys strategy_v = {strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
index 6568c54c24..4adf6c1973 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
@@ -308,8 +308,6 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
@@ -318,7 +316,7 @@
   if (index_size_ != 1) {
     strategy.push_back(1);
   } else {
-    strategy.push_back(SizeToLong(dev_num));
+    strategy.push_back(stage_device_size_);
   }
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy.push_back(1);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
index 3102327572..444134b177 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
@@ -199,10 +199,8 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
   }
 
   // Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) < dev_num) {
+  if (product_p < stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
     return FAILED;
   }
@@ -272,10 +270,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
 
   // param_strategy(axis) != 1, Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
+  if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
     MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
@@ -349,13 +345,11 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   } else {
     out_dev_matrix_shape_ = dev_matrix_shape_;
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
-  if (param_product * index_product < SizeToInt(dev_num)) {
+  if (param_product * index_product < stage_device_size_) {
     // add the repeated calculation num to the last dimension of dev matrix
-    out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
+    out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
   }
 
   return SUCCESS;
@@ -539,11 +533,8 @@ Status GatherV2PInfo::InferGroup() {
     dim = (axis_ + 1) % 2;
   }
 
-  CheckGlobalDeviceManager();
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
@@ -777,11 +768,10 @@ std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
   if (manual_split_) {
     MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+
   Dimensions param_strategy(inputs_shape_[0].size(), 1);
   Dimensions index_strategy;
-  index_strategy.push_back(SizeToLong(dev_num));
+  index_strategy.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     index_strategy.push_back(1);
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
index b85e63c11e..b836a339d0 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
@@ -66,7 +66,7 @@ Strategys GetNextInfo::GetOutputStrategy() {
   Strategys outputs_strategy;
   for (auto shp : shapes_) {
     Dimensions out_strategy;
-    out_strategy.push_back(dev_num_);
+    out_strategy.push_back(stage_device_size_);
     for (size_t i = 1; i < shp.size(); ++i) {
       out_strategy.push_back(1);
     }
@@ -97,7 +97,7 @@ Status GetNextInfo::InferDevMatrixShape() {
   if (max_shape_length == 0) {
     MS_LOG(ERROR) << name_ << " : shape is 0";
   }
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   for (size_t i = 1; i < max_shape_length; ++i) {
     dev_matrix_shape_.push_back(1);
   }
@@ -125,9 +125,6 @@ Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   return SUCCESS;
 }
 
@@ -199,16 +196,16 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
 
   Shapes out_shapes = outputs_shape_;
   for (size_t i = 0; i < out_shapes.size(); ++i) {
-    if (dev_num_ <= 0) {
+    if (stage_device_size_ <= 0) {
       MS_LOG(ERROR) << name_ << " : The dev num is 0.";
       return FAILED;
     }
     if (!full_batch) {
-      if (out_shapes[i][0] % dev_num_ != 0) {
+      if (out_shapes[i][0] % stage_device_size_ != 0) {
         MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
         return FAILED;
       }
-      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+      out_shapes[i][0] = out_shapes[i][0] / stage_device_size_;
     }
   }
   ValuePtr new_shapes = MakeValue(out_shapes);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
index 3516222a65..6829766057 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
@@ -601,10 +601,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
 }
 
 std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
-  batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
+  batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
   Strategys strategy_v = {batch_strategy, batch_strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
index 457cd84420..8a700fb66a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
@@ -268,9 +268,7 @@ Status OneHotInfo::GenerateStrategies(int64_t stage_id) {
 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 
 std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions strategy = {SizeToLong(dev_num), 1};
+  Dimensions strategy = {stage_device_size_, 1};
   Dimensions empty_strategy;
   Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
   return std::make_shared<Strategys>(strategy_v);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 4cd500a71c..89003c3aac 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -688,7 +688,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
     return nullptr;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   Strategys strategy_v;
   for (size_t i = 0; i != shapes.size(); i++) {
     if (shapes[i].empty()) {
@@ -1393,9 +1393,7 @@ Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type)
 
 void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
   if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
-    CheckGlobalDeviceManager();
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
-    if (LongToSize(stra->GetInputDim()[0][0]) == total_device_num) {
+    if (stra->GetInputDim()[0][0] == stage_device_size_) {
       if (cost->computation_cost_ > 1.0) {
         cost->computation_cost_ -= 1.0;
       }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
index 26acd68d6c..26f01b7904 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
@@ -233,15 +233,13 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_strategy(inputs_shape_[0].size(), 1);
   // axis can't split
   if (inputs_shape_[0].size() > 1) {
     if (axis_ == 0) {
-      input_strategy[1] = dev_num;
+      input_strategy[1] = stage_device_size_;
     } else {
-      input_strategy[0] = dev_num;
+      input_strategy[0] = stage_device_size_;
     }
   }
   Strategys strategy_v = {input_strategy};
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.cc
index 3bdf626cf8..b08d023f63 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.cc
@@ -408,17 +408,14 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_a_strategy(inputs_shape_[0].size(), 1);
   Dimensions input_b_strategy(inputs_shape_[1].size(), 1);
 
-  input_a_strategy[0] = SizeToInt(dev_num);
+  input_a_strategy[0] = stage_device_size_;
 
   if (axes_type_ == INT_TYPE) {
     if (IntToSize(axes_int_) == inputs_shape_[0].size()) {
-      input_b_strategy[0] = SizeToInt(dev_num);  // find the relavent dimension for input_b
+      input_b_strategy[0] = stage_device_size_;  // find the relavent dimension for input_b
     }
   } else if (axes_type_ == TUPLE_TUPLE_TYPE) {
     // if the input_a's axes contain 0, the input_b has the relavent dimension with batch dimension
@@ -434,7 +431,7 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
 
     if (found) {
       // find the relavant
-      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = dev_num;
+      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = stage_device_size_;
     }
   } else {
     MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
index 118584759f..8b70389d45 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
@@ -85,7 +85,7 @@ Status UniqueInfo::InferTensorInfo() {
 }
 
 Status UniqueInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
 
@@ -110,9 +110,7 @@ Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
+
   if (stras[0][0] != 1) {
     MS_LOG(ERROR) << "Currently, unique only support repeat calculate in all devices";
     return FAILED;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc
index 12288c3f2c..f8316307f9 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc
@@ -277,20 +277,17 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
 
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
 
-  Dimensions strategy_a;
-  Dimensions strategy_b;
-  strategy_a.push_back(SizeToInt(dev_num));
+  Dimensions strategy_a, strategy_b;
+  strategy_a.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy_a.push_back(1);
   }
-  strategy_b.push_back(SizeToInt(dev_num));
+  strategy_b.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     strategy_b.push_back(1);
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
index 871758409f..c8fa56f5d4 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
@@ -66,13 +66,10 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status VirtualDatasetInfo::InferDevMatrixShape() {
   Strategys stra = strategy_->GetInputDim();
   Dimensions strategy_first = stra.at(0);
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
   int64_t batch_split_num = ((int64_t)(strategy_first.at(0)));
   dev_matrix_shape_.push_back(batch_split_num);
-  if (dev_num > batch_split_num) {
-    dev_matrix_shape_.push_back(dev_num / batch_split_num);
+  if (stage_device_size_ > batch_split_num) {
+    dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num);
   }
 
   return SUCCESS;
@@ -156,11 +153,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
     return FAILED;
   }
 
-  CheckGlobalDeviceManager();
   if (full_batch) {
     total_dev_num = 1;
   } else {
-    total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+    total_dev_num = stage_device_size_;
   }
   StrategyPtr sp;
   Strategys strategy;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index a5d28584d5..a3a76b40b1 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1640,7 +1640,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
   if (full_batch) {
     dev_num = 1;
   } else {
-    dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+    dev_num = SizeToLong(g_device_manager->stage_device_num());
   }
   auto attrs_temp = prim->attrs();
   std::vector<Shapes> shape_list = ExtractShape(node);
@@ -1984,7 +1984,7 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
     return next_layout;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout input_tensor_layout;
   // create input_shape
   Shapes inputs_shape = GetNodeShape(node);
@@ -2009,7 +2009,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   TensorRedistribution tensor_redistribution;
   // create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout stand_alone_layout;
   Shapes inputs_shape = GetNodeShape(node);
   if (inputs_shape.empty()) {
@@ -2029,7 +2029,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   }
 
   // Infer Redistribution op list for stand alone and loss layout.
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
+  RankList dev_list = g_device_manager->GetDeviceListInThisStage();
   if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
     MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
   }
@@ -3093,7 +3093,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
   if (full_batch) {
     return;
   }
-  auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto dev_num = g_device_manager->stage_device_num();
   auto parameters = root->parameters();
   for (auto &parameter : parameters) {
     if (IsUsedParameter(root, parameter)) {