diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
index 4906034964..f33353ebce 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
@@ -113,7 +113,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
                                     size_t type_length, TypePtr type, CostPtr *cost) {
   MS_EXCEPTION_IF_NULL(prev_op_);
   MS_EXCEPTION_IF_NULL(cost);
-  RankList dev_list = prev_op_->global_device_list();
+  RankList dev_list = prev_op_->stage_device_list();
   TensorRedistribution tensor_redistribution(false);

   // Init TensorRedistribution
diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.cc b/mindspore/ccsrc/frontend/parallel/device_manager.cc
index f2b9043c89..5ec60d8f79 100644
--- a/mindspore/ccsrc/frontend/parallel/device_manager.cc
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.cc
@@ -140,7 +140,12 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
                            const std::string &backend) {
   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
     MS_LOG(ERROR) << "Invalid backend: " << backend;
-    return Status::FAILED;
+    return FAILED;
+  }
+
+  if (stage_map.empty() || devices.empty()) {
+    MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
+    return FAILED;
   }

   for (auto &dev : devices) {
@@ -153,11 +158,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
     int64_t num_device = stage;
     if (num_device > MAX_DEVICE_NUM) {
       MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
-      return Status::FAILED;
+      return FAILED;
     }
     if (num_device <= 0) {
       MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
-      return Status::FAILED;
+      return FAILED;
     }
     RankList curr_dev_list;
     for (int64_t i = 0; i < num_device; ++i) {
@@ -170,10 +175,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
   std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
   device_ = dev;

-  set_global_rank(global_device_rank);
-  set_stage_num(static_cast<int64_t>(stage_map.size()));
-  int64_t stage_id = global_device_rank / static_cast<int64_t>(devices.size() / stage_map.size());
-  set_stage_id(stage_id);
+  global_rank_ = global_device_rank;
+  stage_num_ = static_cast<int64_t>(stage_map.size());
+  stage_id_ = global_device_rank / static_cast<int64_t>(devices.size() / stage_map.size());
+  rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<int64_t>(devices.size()) / stage_num_);
+  stage_device_num_ = static_cast<int64_t>(devices.size()) / stage_num_;

   backend_ = backend;

@@ -185,10 +191,13 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
     gm_.set_world_group(UNDEFINED_WORLD_GROUP);
   }
   MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
-               << ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
-  return Status::SUCCESS;
+               << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
+               << ", the rank index in stage is: " << rank_index_in_stage_;
+  return SUCCESS;
 }

+RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
+
 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
   if (LongToSize(stage_id) >= stage_devices_.size())
     MS_LOG(ERROR) << "the 'stage_id': " << stage_id
@@ -204,49 +213,6 @@ RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
   return res;
 }

-RankList DeviceManager::global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const {
-  RankList res;
-  if (split_num <= 0) {
-    return res;
-  }
-  if (LongToSize(stage_id) >= stage_devices_.size()) {
-    MS_LOG(ERROR) << "the 'stage_id': " << stage_id
-                  << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
-    return res;
-  }
-
-  RankList global_list = GetDeviceListByStageId(stage_id);
-  if (global_list.size() % LongToSize(split_num)) {
-    MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id;
-    return res;
-  }
-
-  std::vector<int64_t> dev_list;
-  (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list));
-
-  size_t index = 0;
-  size_t slice_size = dev_list.size() / LongToSize(split_num);
-  for (int64_t i = 0; i < split_num; ++i) {
-    bool found = false;
-    index = slice_size * LongToSize(i);
-    for (size_t j = 0; j < slice_size; ++j) {
-      if (dev_list[index + j] == rank) {
-        found = true;
-        break;
-      }
-    }
-
-    if (found) {
-      break;
-    }
-  }
-
-  for (size_t k = 0; k < slice_size; ++k) {
-    res.push_back(dev_list[index + k]);
-  }
-  return res;
-}
-
 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }

 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h
index 723142a551..1fdcb2f5aa 100644
--- a/mindspore/ccsrc/frontend/parallel/device_manager.h
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.h
@@ -57,14 +57,14 @@ std::string HashName(const std::string &rank_list_name);

 class DeviceManager {
   // This class is used to manage the abstract devices, including group-related and stage-related management.
  public:
-  DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(1), stage_id_(0) { gm_ = GroupManager(); }
+  DeviceManager() { gm_ = GroupManager(); }
   ~DeviceManager() = default;

   Status Init(const RankList &devices, int64_t local_device, const RankList &stage_map, const std::string &backend);

   static DeviceManager &GetInstance();
   RankList GetDeviceListByStageId(int64_t stage_id) const;
-  RankList global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const;
+  RankList GetDeviceListInThisStage() const;
   Device CreateNewDeviceByRank(int64_t rank) const;
   std::vector<Device> CreateDeviceListByRankList(RankList ranks);
@@ -74,17 +74,11 @@ class DeviceManager {
   Group CreateGroup(const RankList &dev_ranks);
   size_t DeviceNum() const { return devices_.size(); }

-  int64_t stage_num() const { return stage_num_; }
-  void set_stage_num(int64_t num) { stage_num_ = num; }
-  int64_t stage_id() const { return stage_id_; }
-  void set_stage_id(int64_t id) { stage_id_ = id; }
-
-  std::string backend() const { return backend_; }
-
+  int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }
-  void set_global_rank(int64_t global_rank) { global_rank_ = global_rank; }
+  std::string backend() const { return backend_; }

   void Clear();
   std::string world_group() const { return gm_.world_group(); }
@@ -102,10 +96,11 @@ class DeviceManager {
   std::map<std::string, std::string> rank_to_group_;  // the key is rank list, value is hash name
   std::map<std::string, std::string> group_to_rank_;  // the key is hash name, value is rank list

-  int64_t local_rank_;
-  int64_t global_rank_;
-  int64_t stage_num_;
-  int64_t stage_id_;
+  int64_t global_rank_ = 0;          // the real rank in all devices
+  int64_t stage_num_ = 0;            // the stage num
+  int64_t stage_id_ = 0;             // the stage id of the global_rank_
+  int64_t rank_index_in_stage_ = 0;  // the index of this rank in its stage
+  int64_t stage_device_num_ = 0;     // the device num of one stage
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
index cf3d37eb01..6568c54c24 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
@@ -232,7 +232,7 @@ Status GatherV2Info::InferTensorSubOps() {
     MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
   }
   int64_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   int64_t mod_rank = rank % mod_p;
   mod_rank = static_cast<int64_t>(mod_rank / mod_n);
   if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
index a7ef6dd7c1..3102327572 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
@@ -451,7 +451,7 @@ Status GatherV2PInfo::InferTensorInfo() {
   Shape input_shape = inputs_shape_.at(0);
   Shape input_index_shape = inputs_shape_.at(1);
   Shape output_shape = outputs_shape_.at(0);
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   // infer tensor layout
   TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
   if (manual_split_) {
@@ -481,7 +481,7 @@ Status GatherV2PInfo::InferBias() {
   CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   auto input_shape = inputs_shape_.at(0);
   auto params_strategy = strategy_->GetInputDim().at(0);
   // axis don't split
@@ -513,7 +513,7 @@ Status GatherV2PInfo::InferBias() {

 Status GatherV2PInfo::InferOffset() {
   CheckGlobalDeviceManager();
-  size_t rank = g_device_manager->global_rank();
+  size_t rank = g_device_manager->rank_index_in_stage();
   MS_EXCEPTION_IF_NULL(strategy_);
   auto param_strategy = strategy_->GetInputDim()[0];
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
index 935b0650e3..457cd84420 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
@@ -134,7 +134,7 @@ Status OneHotInfo::InferTensorInfo() {

 Status OneHotInfo::ExtractInputInfo() {
   CheckGlobalDeviceManager();
-  rank_ = g_device_manager->global_rank();
+  rank_ = g_device_manager->rank_index_in_stage();
   mod_rank_ = rank_ % old_dev_matrix_back_;
   if (!cnode_) {
     MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 36c831bf9e..4cd500a71c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -116,7 +116,6 @@ void OperatorInfo::ResetQueueMember() {
   replace_op_.clear();
   replace_op_info_.clear();
   virtual_div_op_.clear();
-  global_device_list_.clear();
 }

 Status OperatorInfo::InferAttrs() {
@@ -131,14 +130,8 @@ Status OperatorInfo::InferAttrs() {
   return SUCCESS;
 }

-void OperatorInfo::SetDeviceListByStrategy() {
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
-}
-
 Status OperatorInfo::InferRepeatedCalcInfo() {
-  int64_t g_dev_list_size = SizeToLong(global_device_list_.size());
+  int64_t g_dev_list_size = stage_device_size_;
   int64_t dev_matrix_size =
       std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
   if (dev_matrix_size == 0) {
@@ -155,12 +148,6 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
                   << dev_matrix_size;
     return FAILED;
   }
-
-  CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
-  int64_t stage = strategy_->GetInputStage();
-  local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_);
-
   return SUCCESS;
 }
@@ -331,7 +318,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
   }
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
     return FAILED;
   }
@@ -354,7 +341,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
   }
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
     return FAILED;
   }
@@ -469,7 +456,6 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
   ResetQueueMember();

   strategy_ = strategy;
-  SetDeviceListByStrategy();

   if (InferDevMatrixShape() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@@ -526,7 +512,6 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &str
   ResetQueueMember();

   strategy_ = strategy;
-  SetDeviceListByStrategy();

   if (InferDevMatrixShape() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@@ -1325,7 +1310,7 @@ Status OperatorInfo::InferAsLossDivisor() {
   }

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToLong(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index f5b78075a6..813b64d13d 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -64,6 +64,8 @@ class OperatorInfo {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
     refkey_parameter_name_ = "";
+    stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
+    stage_device_size_ = SizeToLong(stage_device_list_.size());
   }
   virtual ~OperatorInfo() = default;

@@ -119,7 +121,7 @@ class OperatorInfo {
   std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
-  RankList global_device_list() const { return global_device_list_; }
+  RankList stage_device_list() const { return stage_device_list_; }

   void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
   void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
@@ -187,7 +189,6 @@ class OperatorInfo {
   virtual Status InferTensorInfo() = 0;
   virtual Status InferDevMatrixShape() = 0;
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
-  void SetDeviceListByStrategy();
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();
   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
@@ -231,8 +232,8 @@ class OperatorInfo {
   ReplaceGraphPtr replace_graph_;
   MirrorOps mirror_ops_;
   VirtualDivOp virtual_div_op_;
-  RankList global_device_list_;  // the size of global_device_list equal to the size of stageID
-  RankList local_device_list_;   // the size equal to global_device_list_.size() / repeated_calc_num_
+  RankList stage_device_list_;  // the device list in this stage
+  int64_t stage_device_size_ = 0;
   bool infer_attrs_completed_ = false;

   bool is_auto_parallel_ = false;  // false: semi_auto_parallel; true: auto_parallel
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc
index 73bd3eeb50..88b61dd790 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc
@@ -136,7 +136,7 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {

 Status RangeInfo::InferNewAttr() {
   CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   // If repeated calculation and repeated num as the last dimension of dev-matrix,
   // the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
index af357b0769..393fe2f751 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
@@ -531,7 +531,7 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() {
   MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToLong(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor.";
     return SUCCESS;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc
index 5ae43f2e1f..3283436ef5 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc
@@ -172,7 +172,7 @@ Status ReLUV2Info::InferAsLossDivisor() {
   }

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
index 1ab418903c..d2fb83397a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
@@ -113,7 +113,7 @@ Status ReshapeInfo::GetParameterInput() {
 }

 Status ReshapeInfo::ComputeReplaceOp() {
-  RankList dev_list = global_device_list();
+  RankList dev_list = stage_device_list();
   TensorRedistribution tensor_redistribution(!is_generating_costs_, true);
   if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
     if (is_generating_costs_) {
@@ -289,13 +289,7 @@ void ReshapeInfo::InferTensorInfoByLayout() {

 Status ReshapeInfo::GetAttrs() { return GetParameterInput(); }

 void ReshapeInfo::device_number(const StrategyPtr &strategy) {
-  int64_t stage = 0;
-  if (strategy != nullptr) {
-    stage = strategy->GetInputStage();
-  }
-  CheckGlobalDeviceManager();
-  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
-  dev_num_ = SizeToLong(global_device_list_.size());
+  dev_num_ = stage_device_size_;
   MS_ASSERT(dev_num_ > 0);
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
index a442ef6877..26acd68d6c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
@@ -260,7 +260,7 @@ Status SplitInfo::InferAsLossDivisor() {
   }

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 28a7790e03..1a6f87df97 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -325,7 +325,7 @@ void Redistribution(const std::pair &node_pair, const Opera
   if (next_distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
   }
-  RankList dev_list = distribute_operator->global_device_list();
+  RankList dev_list = distribute_operator->stage_device_list();
   std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
   MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index 795cbd0372..b30afbfdb2 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -161,6 +161,8 @@ class EmbeddingLookup(Cell):
     Examples:
         >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
         >>> out = nn.EmbeddingLookup(4,2)(input_indices)
+        >>> out.shape
+        (2, 2, 2)
     """
     BATCH_SLICE = "batch_slice"
     FIELD_SLICE = "field_slice"
diff --git a/tests/ut/cpp/parallel/device_manager_test.cc b/tests/ut/cpp/parallel/device_manager_test.cc
index e63125e396..fe90375fb6 100644
--- a/tests/ut/cpp/parallel/device_manager_test.cc
+++ b/tests/ut/cpp/parallel/device_manager_test.cc
@@ -135,6 +135,8 @@ TEST_F(TestDeviceManager, test_StageID) {
   ASSERT_EQ(dm_.DeviceNum(), 4);
   ASSERT_EQ(dm_.stage_num(), 2);
   ASSERT_EQ(dm_.stage_id(), 1);
+  ASSERT_EQ(dm_.rank_index_in_stage(), 0);
+  ASSERT_EQ(dm_.GetDeviceListInThisStage().back(), 3);

   RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
   RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
diff --git a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc
index 3468d921b0..0b7b5476c5 100644
--- a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc
@@ -171,7 +171,7 @@ TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
   StrategyPtr strategy = NewStrategy(0, inputs);

   log_softmax->Init(strategy);
-  RankList dev_list = log_softmax->global_device_list();
+  RankList dev_list = log_softmax->stage_device_list();
   ASSERT_EQ(dev_list.size(), 128);
 }