!8644 update getting device list in parallel ops

From: @yangzhenzhang
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
pull/8644/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e1cfeeb1dd

@ -75,6 +75,7 @@ class DeviceManager {
size_t DeviceNum() const { return devices_.size(); }
int64_t stage_num() const { return stage_num_; }
int64_t stage_device_num() const { return stage_device_num_; }
int64_t stage_id() const { return stage_id_; }
int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
int64_t global_rank() const { return global_rank_; }

@ -41,11 +41,9 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
}
// dropout don't support repeated calculation
CheckGlobalDeviceManager();
auto input_strategy = strategy->GetInputDim().at(0);
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
if (IntToSize(product_p) != dev_num) {
if (product_p != stage_device_size_) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}

@ -32,11 +32,6 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
int64_t stage = strategy->GetInputStage();
CheckGlobalDeviceManager();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
dev_num_ = dev_num;
size_t strategy_size = strategy->GetInputNumber();
Strategys stra = strategy->GetInputDim();
for (size_t i = 0; i < strategy_size; ++i) {
@ -46,7 +41,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
for (size_t j = 0; j < strategy_len; ++j) {
int64_t strategy_value = sub_strategy.at(j);
if (strategy_value > 1) {
if (flag || strategy_value != dev_num_) {
if (flag || strategy_value != stage_device_size_) {
MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
return FAILED;
}
@ -58,7 +53,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
}
Status BatchParallelInfo::InferDevMatrixShape() {
dev_matrix_shape_.push_back(dev_num_);
dev_matrix_shape_.push_back(stage_device_size_);
return SUCCESS;
}
@ -81,14 +76,14 @@ Status BatchParallelInfo::InferMirrorOps() {
Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
Status BatchParallelInfo::InferTensorMap() {
if (strategy_->GetInputDim()[0][0] != dev_num_) {
if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
return FAILED;
}
for (size_t i = 0; i < inputs_shape_.size(); i++) {
Shape tensor_map_index;
for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) {
if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
tensor_map_index.push_back(0);
} else {
tensor_map_index.push_back(MAP_NONE);
@ -117,7 +112,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
Dimensions strategy;
for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
if (i == 0 && j == 0) {
strategy.push_back(dev_num_);
strategy.push_back(stage_device_size_);
} else {
strategy.push_back(1);
}
@ -176,14 +171,12 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
}
Status BatchParallelInfo::GenerateStrategies(int64_t stage_id) {
CheckGlobalDeviceManager();
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
StrategyPtr sp;
Strategys strategy;
for (size_t i = 0; i < inputs_shape_.size(); i++) {
Shape temp(inputs_shape_[i].size(), 1);
if (split_flag_list_[i]) {
temp[0] = SizeToLong(total_dev_num);
temp[0] = stage_device_size_;
}
strategy.push_back(temp);
}

@ -151,10 +151,8 @@ Status DropoutDoMaskInfo::GenerateStrategies(int64_t stage_id) {
}
std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions strategy(inputs_shape_[0].size() - 1, 1);
(void)strategy.insert(strategy.begin(), SizeToLong(dev_num));
(void)strategy.insert(strategy.begin(), stage_device_size_);
Strategys strategy_v = {strategy};
return std::make_shared<Strategys>(strategy_v);
}

@ -308,8 +308,6 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << "GetAttrs failed!";
}
@ -318,7 +316,7 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
if (index_size_ != 1) {
strategy.push_back(1);
} else {
strategy.push_back(SizeToLong(dev_num));
strategy.push_back(stage_device_size_);
}
for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
strategy.push_back(1);

@ -199,10 +199,8 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
}
// Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
if (IntToSize(product_p) < dev_num) {
if (product_p < stage_device_size_) {
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
return FAILED;
}
@ -272,10 +270,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
}
// param_strategy(axis) != 1, Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}
@ -349,13 +345,11 @@ Status GatherV2PInfo::InferDevMatrixShape() {
} else {
out_dev_matrix_shape_ = dev_matrix_shape_;
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
if (param_product * index_product < SizeToInt(dev_num)) {
if (param_product * index_product < stage_device_size_) {
// add the repeated calculation num to the last dimension of dev matrix
out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
}
return SUCCESS;
@ -539,11 +533,8 @@ Status GatherV2PInfo::InferGroup() {
dim = (axis_ + 1) % 2;
}
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group failed.";
@ -777,11 +768,10 @@ std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
if (manual_split_) {
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions param_strategy(inputs_shape_[0].size(), 1);
Dimensions index_strategy;
index_strategy.push_back(SizeToLong(dev_num));
index_strategy.push_back(stage_device_size_);
for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
index_strategy.push_back(1);
}

@ -66,7 +66,7 @@ Strategys GetNextInfo::GetOutputStrategy() {
Strategys outputs_strategy;
for (auto shp : shapes_) {
Dimensions out_strategy;
out_strategy.push_back(dev_num_);
out_strategy.push_back(stage_device_size_);
for (size_t i = 1; i < shp.size(); ++i) {
out_strategy.push_back(1);
}
@ -97,7 +97,7 @@ Status GetNextInfo::InferDevMatrixShape() {
if (max_shape_length == 0) {
MS_LOG(ERROR) << name_ << " : shape is 0";
}
dev_matrix_shape_.push_back(dev_num_);
dev_matrix_shape_.push_back(stage_device_size_);
for (size_t i = 1; i < max_shape_length; ++i) {
dev_matrix_shape_.push_back(1);
}
@ -125,9 +125,6 @@ Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
}
int64_t stage = strategy->GetInputStage();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
dev_num_ = dev_num;
return SUCCESS;
}
@ -199,16 +196,16 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
Shapes out_shapes = outputs_shape_;
for (size_t i = 0; i < out_shapes.size(); ++i) {
if (dev_num_ <= 0) {
if (stage_device_size_ <= 0) {
MS_LOG(ERROR) << name_ << " : The dev num is 0.";
return FAILED;
}
if (!full_batch) {
if (out_shapes[i][0] % dev_num_ != 0) {
if (out_shapes[i][0] % stage_device_size_ != 0) {
MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
return FAILED;
}
out_shapes[i][0] = out_shapes[i][0] / dev_num_;
out_shapes[i][0] = out_shapes[i][0] / stage_device_size_;
}
}
ValuePtr new_shapes = MakeValue(out_shapes);

@ -601,10 +601,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
}
std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
Strategys strategy_v = {batch_strategy, batch_strategy};
return std::make_shared<Strategys>(strategy_v);
}

@ -268,9 +268,7 @@ Status OneHotInfo::GenerateStrategies(int64_t stage_id) {
Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions strategy = {SizeToLong(dev_num), 1};
Dimensions strategy = {stage_device_size_, 1};
Dimensions empty_strategy;
Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
return std::make_shared<Strategys>(strategy_v);

@ -688,7 +688,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
return nullptr;
}
CheckGlobalDeviceManager();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
int64_t dev_num = g_device_manager->stage_device_num();
Strategys strategy_v;
for (size_t i = 0; i != shapes.size(); i++) {
if (shapes[i].empty()) {
@ -1393,9 +1393,7 @@ Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type)
void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
CheckGlobalDeviceManager();
auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
if (LongToSize(stra->GetInputDim()[0][0]) == total_device_num) {
if (stra->GetInputDim()[0][0] == stage_device_size_) {
if (cost->computation_cost_ > 1.0) {
cost->computation_cost_ -= 1.0;
}

@ -233,15 +233,13 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions input_strategy(inputs_shape_[0].size(), 1);
// axis can't split
if (inputs_shape_[0].size() > 1) {
if (axis_ == 0) {
input_strategy[1] = dev_num;
input_strategy[1] = stage_device_size_;
} else {
input_strategy[0] = dev_num;
input_strategy[0] = stage_device_size_;
}
}
Strategys strategy_v = {input_strategy};

@ -408,17 +408,14 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions input_a_strategy(inputs_shape_[0].size(), 1);
Dimensions input_b_strategy(inputs_shape_[1].size(), 1);
input_a_strategy[0] = SizeToInt(dev_num);
input_a_strategy[0] = stage_device_size_;
if (axes_type_ == INT_TYPE) {
if (IntToSize(axes_int_) == inputs_shape_[0].size()) {
input_b_strategy[0] = SizeToInt(dev_num); // find the relavent dimension for input_b
input_b_strategy[0] = stage_device_size_; // find the relavent dimension for input_b
}
} else if (axes_type_ == TUPLE_TUPLE_TYPE) {
// if the input_a's axes contain 0, the input_b has the relevant dimension with batch dimension
@ -434,7 +431,7 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
if (found) {
// find the relevant dimension
input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = dev_num;
input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = stage_device_size_;
}
} else {
MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";

@ -85,7 +85,7 @@ Status UniqueInfo::InferTensorInfo() {
}
Status UniqueInfo::InferDevMatrixShape() {
dev_matrix_shape_.push_back(dev_num_);
dev_matrix_shape_.push_back(stage_device_size_);
return SUCCESS;
}
@ -110,9 +110,7 @@ Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
}
int64_t stage = strategy->GetInputStage();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
dev_num_ = dev_num;
if (stras[0][0] != 1) {
MS_LOG(ERROR) << "Currently, unique only support repeat calculate in all devices";
return FAILED;

@ -277,20 +277,17 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << "GetAttrs failed!";
}
Dimensions strategy_a;
Dimensions strategy_b;
strategy_a.push_back(SizeToInt(dev_num));
Dimensions strategy_a, strategy_b;
strategy_a.push_back(stage_device_size_);
for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
strategy_a.push_back(1);
}
strategy_b.push_back(SizeToInt(dev_num));
strategy_b.push_back(stage_device_size_);
for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
strategy_b.push_back(1);
}

@ -66,13 +66,10 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
Status VirtualDatasetInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim();
Dimensions strategy_first = stra.at(0);
int64_t stage = strategy_->GetInputStage();
CheckGlobalDeviceManager();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
int64_t batch_split_num = ((int64_t)(strategy_first.at(0)));
dev_matrix_shape_.push_back(batch_split_num);
if (dev_num > batch_split_num) {
dev_matrix_shape_.push_back(dev_num / batch_split_num);
if (stage_device_size_ > batch_split_num) {
dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num);
}
return SUCCESS;
@ -156,11 +153,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
return FAILED;
}
CheckGlobalDeviceManager();
if (full_batch) {
total_dev_num = 1;
} else {
total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
total_dev_num = stage_device_size_;
}
StrategyPtr sp;
Strategys strategy;

@ -1640,7 +1640,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
if (full_batch) {
dev_num = 1;
} else {
dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
dev_num = SizeToLong(g_device_manager->stage_device_num());
}
auto attrs_temp = prim->attrs();
std::vector<Shapes> shape_list = ExtractShape(node);
@ -1984,7 +1984,7 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
return next_layout;
}
CheckGlobalDeviceManager();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
int64_t dev_num = g_device_manager->stage_device_num();
TensorLayout input_tensor_layout;
// create input_shape
Shapes inputs_shape = GetNodeShape(node);
@ -2009,7 +2009,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
TensorRedistribution tensor_redistribution;
// create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
CheckGlobalDeviceManager();
int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
int64_t dev_num = g_device_manager->stage_device_num();
TensorLayout stand_alone_layout;
Shapes inputs_shape = GetNodeShape(node);
if (inputs_shape.empty()) {
@ -2029,7 +2029,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
}
// Infer Redistribution op list for stand alone and loss layout.
RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
RankList dev_list = g_device_manager->GetDeviceListInThisStage();
if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
}
@ -3093,7 +3093,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
if (full_batch) {
return;
}
auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
auto dev_num = g_device_manager->stage_device_num();
auto parameters = root->parameters();
for (auto &parameter : parameters) {
if (IsUsedParameter(root, parameter)) {

Loading…
Cancel
Save