!8644 update getting device list in parallel ops

From: @yangzhenzhang
Reviewed-by: @stsuteng, @kisnwang
Signed-off-by: @stsuteng
pull/8644/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit e1cfeeb1dd
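Summary of the change: instead of re-querying the global device manager on every call (CheckGlobalDeviceManager(); g_device_manager->GetDeviceListByStageId(...).size()), the parallel operator infos now read the cached per-stage members stage_device_size_ / stage_device_list_, and DeviceManager exposes a stage_device_num() accessor plus GetDeviceListInThisStage(). The sketch below only illustrates that caching pattern under simplified assumptions; DeviceManagerSketch, its constructor, and the even split of devices across stages are hypothetical stand-ins, not the real MindSpore classes.

// Minimal, self-contained sketch of the refactoring pattern (hypothetical names).
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using RankList = std::vector<int64_t>;

class DeviceManagerSketch {
 public:
  DeviceManagerSketch(RankList devices, int64_t stage_num)
      : devices_(std::move(devices)),
        stage_device_num_(static_cast<int64_t>(devices_.size()) / stage_num) {}

  // Old style: rebuild the stage's device list on every query, then take its size.
  RankList GetDeviceListByStageId(int64_t stage_id) const {
    RankList stage_devices;
    for (int64_t i = stage_id * stage_device_num_; i < (stage_id + 1) * stage_device_num_; ++i) {
      stage_devices.push_back(devices_[i]);
    }
    return stage_devices;
  }

  // New style: return the per-stage device count cached at construction time.
  int64_t stage_device_num() const { return stage_device_num_; }

 private:
  RankList devices_;
  int64_t stage_device_num_;
};

int main() {
  DeviceManagerSketch mgr({0, 1, 2, 3, 4, 5, 6, 7}, /*stage_num=*/2);
  // Both report 4 devices per stage; the second avoids materializing the list.
  std::cout << mgr.GetDeviceListByStageId(0).size() << " " << mgr.stage_device_num() << std::endl;
  return 0;
}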

@@ -75,6 +75,7 @@ class DeviceManager {
   size_t DeviceNum() const { return devices_.size(); }
   int64_t stage_num() const { return stage_num_; }
+  int64_t stage_device_num() const { return stage_device_num_; }
   int64_t stage_id() const { return stage_id_; }
   int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }

@@ -41,11 +41,9 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
   // dropout don't support repeated calculation
-  CheckGlobalDeviceManager();
   auto input_strategy = strategy->GetInputDim().at(0);
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num) {
+  if (product_p != stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }

@@ -32,11 +32,6 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
-  int64_t stage = strategy->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   size_t strategy_size = strategy->GetInputNumber();
   Strategys stra = strategy->GetInputDim();
   for (size_t i = 0; i < strategy_size; ++i) {
@@ -46,7 +41,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     for (size_t j = 0; j < strategy_len; ++j) {
       int64_t strategy_value = sub_strategy.at(j);
       if (strategy_value > 1) {
-        if (flag || strategy_value != dev_num_) {
+        if (flag || strategy_value != stage_device_size_) {
           MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
           return FAILED;
         }
@@ -58,7 +53,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
 }
 Status BatchParallelInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
@@ -81,14 +76,14 @@ Status BatchParallelInfo::InferMirrorOps() {
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
 Status BatchParallelInfo::InferTensorMap() {
-  if (strategy_->GetInputDim()[0][0] != dev_num_) {
+  if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
     MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
     return FAILED;
   }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) {
+      if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
         tensor_map_index.push_back(0);
       } else {
         tensor_map_index.push_back(MAP_NONE);
@@ -117,7 +112,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
     Dimensions strategy;
     for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
       if (i == 0 && j == 0) {
-        strategy.push_back(dev_num_);
+        strategy.push_back(stage_device_size_);
       } else {
         strategy.push_back(1);
       }
@@ -176,14 +171,12 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }
 Status BatchParallelInfo::GenerateStrategies(int64_t stage_id) {
-  CheckGlobalDeviceManager();
-  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
   StrategyPtr sp;
   Strategys strategy;
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape temp(inputs_shape_[i].size(), 1);
     if (split_flag_list_[i]) {
-      temp[0] = SizeToLong(total_dev_num);
+      temp[0] = stage_device_size_;
     }
     strategy.push_back(temp);
   }

@@ -151,10 +151,8 @@ Status DropoutDoMaskInfo::GenerateStrategies(int64_t stage_id) {
 }
 std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions strategy(inputs_shape_[0].size() - 1, 1);
-  (void)strategy.insert(strategy.begin(), SizeToLong(dev_num));
+  (void)strategy.insert(strategy.begin(), stage_device_size_);
   Strategys strategy_v = {strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
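The same GenerateBatchStrategies pattern repeats in most of the hunks below (GatherV2, GatherV2P, BatchMatMul, OneHot, Split, TensorDot, UnsortedSegmentOp): split only the leading batch dimension of each input across the stage's devices and leave every other dimension at 1, now taking the device count from stage_device_size_. A self-contained sketch of that strategy construction follows, using std::vector stand-ins for the Dimensions/Strategys types and a hypothetical helper name.

// Sketch only: illustrates the batch-strategy shape, not the real operator code.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int64_t>;    // split factors for one input
using Strategys = std::vector<Dimensions>;  // one Dimensions entry per input

// Hypothetical helper: build a pure data-parallel strategy for an input of rank
// `input_rank`, splitting only dimension 0 across `stage_device_size` devices.
Dimensions BatchStrategyForInput(std::size_t input_rank, int64_t stage_device_size) {
  Dimensions strategy(input_rank, 1);  // start with "no split" on every dimension
  if (!strategy.empty()) {
    strategy[0] = stage_device_size;   // split only the batch dimension
  }
  return strategy;
}

int main() {
  // e.g. a rank-3 input on a stage with 8 devices -> strategy [8, 1, 1]
  Strategys strategy_v = {BatchStrategyForInput(3, 8)};
  for (int64_t s : strategy_v[0]) {
    std::cout << s << " ";
  }
  std::cout << std::endl;
  return 0;
}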

@@ -308,8 +308,6 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
@@ -318,7 +316,7 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
   if (index_size_ != 1) {
     strategy.push_back(1);
   } else {
-    strategy.push_back(SizeToLong(dev_num));
+    strategy.push_back(stage_device_size_);
   }
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy.push_back(1);

@@ -199,10 +199,8 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
   }
   // Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) < dev_num) {
+  if (product_p < stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
     return FAILED;
   }
@@ -272,10 +270,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
   // param_strategy(axis) != 1, Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
+  if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
     MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
@@ -349,13 +345,11 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   } else {
     out_dev_matrix_shape_ = dev_matrix_shape_;
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
-  if (param_product * index_product < SizeToInt(dev_num)) {
+  if (param_product * index_product < stage_device_size_) {
     // add the repeated calculation num to the last dimension of dev matrix
-    out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
+    out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
   }
   return SUCCESS;
@@ -539,11 +533,8 @@ Status GatherV2PInfo::InferGroup() {
     dim = (axis_ + 1) % 2;
   }
-  CheckGlobalDeviceManager();
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
@@ -777,11 +768,10 @@ std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
   if (manual_split_) {
     MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions param_strategy(inputs_shape_[0].size(), 1);
   Dimensions index_strategy;
-  index_strategy.push_back(SizeToLong(dev_num));
+  index_strategy.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     index_strategy.push_back(1);
   }

@@ -66,7 +66,7 @@ Strategys GetNextInfo::GetOutputStrategy() {
   Strategys outputs_strategy;
   for (auto shp : shapes_) {
     Dimensions out_strategy;
-    out_strategy.push_back(dev_num_);
+    out_strategy.push_back(stage_device_size_);
     for (size_t i = 1; i < shp.size(); ++i) {
       out_strategy.push_back(1);
     }
@@ -97,7 +97,7 @@ Status GetNextInfo::InferDevMatrixShape() {
   if (max_shape_length == 0) {
     MS_LOG(ERROR) << name_ << " : shape is 0";
   }
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   for (size_t i = 1; i < max_shape_length; ++i) {
     dev_matrix_shape_.push_back(1);
   }
@@ -125,9 +125,6 @@ Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   return SUCCESS;
 }
@@ -199,16 +196,16 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
   Shapes out_shapes = outputs_shape_;
   for (size_t i = 0; i < out_shapes.size(); ++i) {
-    if (dev_num_ <= 0) {
+    if (stage_device_size_ <= 0) {
      MS_LOG(ERROR) << name_ << " : The dev num is 0.";
      return FAILED;
    }
    if (!full_batch) {
-      if (out_shapes[i][0] % dev_num_ != 0) {
+      if (out_shapes[i][0] % stage_device_size_ != 0) {
        MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
        return FAILED;
      }
-      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+      out_shapes[i][0] = out_shapes[i][0] / stage_device_size_;
    }
  }
  ValuePtr new_shapes = MakeValue(out_shapes);

@@ -601,10 +601,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
 }
 std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
-  batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
+  batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
   Strategys strategy_v = {batch_strategy, batch_strategy};
   return std::make_shared<Strategys>(strategy_v);
 }

@@ -268,9 +268,7 @@ Status OneHotInfo::GenerateStrategies(int64_t stage_id) {
 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions strategy = {SizeToLong(dev_num), 1};
+  Dimensions strategy = {stage_device_size_, 1};
   Dimensions empty_strategy;
   Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
   return std::make_shared<Strategys>(strategy_v);

@@ -688,7 +688,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
     return nullptr;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   Strategys strategy_v;
   for (size_t i = 0; i != shapes.size(); i++) {
     if (shapes[i].empty()) {
@@ -1393,9 +1393,7 @@ Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type)
 void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
   if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
-    CheckGlobalDeviceManager();
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
-    if (LongToSize(stra->GetInputDim()[0][0]) == total_device_num) {
+    if (stra->GetInputDim()[0][0] == stage_device_size_) {
       if (cost->computation_cost_ > 1.0) {
         cost->computation_cost_ -= 1.0;
       }

@@ -233,15 +233,13 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_strategy(inputs_shape_[0].size(), 1);
   // axis can't split
   if (inputs_shape_[0].size() > 1) {
     if (axis_ == 0) {
-      input_strategy[1] = dev_num;
+      input_strategy[1] = stage_device_size_;
     } else {
-      input_strategy[0] = dev_num;
+      input_strategy[0] = stage_device_size_;
     }
   }
   Strategys strategy_v = {input_strategy};

@@ -408,17 +408,14 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_a_strategy(inputs_shape_[0].size(), 1);
   Dimensions input_b_strategy(inputs_shape_[1].size(), 1);
-  input_a_strategy[0] = SizeToInt(dev_num);
+  input_a_strategy[0] = stage_device_size_;
   if (axes_type_ == INT_TYPE) {
     if (IntToSize(axes_int_) == inputs_shape_[0].size()) {
-      input_b_strategy[0] = SizeToInt(dev_num);  // find the relavent dimension for input_b
+      input_b_strategy[0] = stage_device_size_;  // find the relavent dimension for input_b
     }
   } else if (axes_type_ == TUPLE_TUPLE_TYPE) {
     // if the input_a's axes contain 0, the input_b has the relavent dimension with batch dimension
@@ -434,7 +431,7 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
     if (found) {
       // find the relavant
-      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = dev_num;
+      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = stage_device_size_;
     }
   } else {
     MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";

@@ -85,7 +85,7 @@ Status UniqueInfo::InferTensorInfo() {
 }
 Status UniqueInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
@@ -110,9 +110,7 @@ Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   if (stras[0][0] != 1) {
     MS_LOG(ERROR) << "Currently, unique only support repeat calculate in all devices";
     return FAILED;

@@ -277,20 +277,17 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
-  Dimensions strategy_a;
-  Dimensions strategy_b;
-  strategy_a.push_back(SizeToInt(dev_num));
+  Dimensions strategy_a, strategy_b;
+  strategy_a.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy_a.push_back(1);
   }
-  strategy_b.push_back(SizeToInt(dev_num));
+  strategy_b.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     strategy_b.push_back(1);
   }

@@ -66,13 +66,10 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status VirtualDatasetInfo::InferDevMatrixShape() {
   Strategys stra = strategy_->GetInputDim();
   Dimensions strategy_first = stra.at(0);
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
   int64_t batch_split_num = ((int64_t)(strategy_first.at(0)));
   dev_matrix_shape_.push_back(batch_split_num);
-  if (dev_num > batch_split_num) {
-    dev_matrix_shape_.push_back(dev_num / batch_split_num);
+  if (stage_device_size_ > batch_split_num) {
+    dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num);
   }
   return SUCCESS;
@@ -156,11 +153,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
     return FAILED;
   }
-  CheckGlobalDeviceManager();
   if (full_batch) {
     total_dev_num = 1;
   } else {
-    total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+    total_dev_num = stage_device_size_;
   }
   StrategyPtr sp;
   Strategys strategy;

@@ -1640,7 +1640,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
   if (full_batch) {
     dev_num = 1;
   } else {
-    dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+    dev_num = SizeToLong(g_device_manager->stage_device_num());
   }
   auto attrs_temp = prim->attrs();
   std::vector<Shapes> shape_list = ExtractShape(node);
@@ -1984,7 +1984,7 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
     return next_layout;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout input_tensor_layout;
   // create input_shape
   Shapes inputs_shape = GetNodeShape(node);
@@ -2009,7 +2009,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   TensorRedistribution tensor_redistribution;
   // create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout stand_alone_layout;
   Shapes inputs_shape = GetNodeShape(node);
   if (inputs_shape.empty()) {
@@ -2029,7 +2029,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   }
   // Infer Redistribution op list for stand alone and loss layout.
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
+  RankList dev_list = g_device_manager->GetDeviceListInThisStage();
   if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
     MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
   }
@@ -3093,7 +3093,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
   if (full_batch) {
     return;
   }
-  auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto dev_num = g_device_manager->stage_device_num();
   auto parameters = root->parameters();
   for (auto &parameter : parameters) {
     if (IsUsedParameter(root, parameter)) {
