From 70aa0dc5e233a6907c5d6bb6a9b3a5eef0365ad0 Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Fri, 19 Feb 2021 14:30:41 +0800
Subject: [PATCH] modify get output layout

---
 .../ccsrc/frontend/parallel/step_parallel.cc | 293 ++++++++++--------
 1 file changed, 160 insertions(+), 133 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 83faf919ed..8f6c313b9a 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -956,64 +956,71 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
   }
 }
 
+static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  if (IsValueNode<RefKey>(node)) {
+    std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
+    if (param_v.size() != 1) {
+      MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
+                        << param_v.size();
+    }
+    auto param_ptr = param_v[0]->user_data<TensorLayout>();
+    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+      return std::make_pair(nullptr, true);
+    }
+    return std::make_pair(node, true);
+  }
+  return std::make_pair(nullptr, false);
+}
+
 // Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
     return std::make_pair(nullptr, false);
-  } else if (node->isa<Parameter>()) {
+  }
+
+  if (node->isa<Parameter>()) {
     auto param_ptr = node->user_data<TensorLayout>();
     if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
       return std::make_pair(nullptr, false);
-    } else {
-      return std::make_pair(node, false);
-    }
-  } else if (node->isa<ValueNode>()) {
-    if (IsValueNode<RefKey>(node)) {
-      std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
-      if (param_v.size() != 1) {
-        MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
-                          << param_v.size();
-      }
-      auto param_ptr = param_v[0]->user_data<TensorLayout>();
-      if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
-        return std::make_pair(nullptr, true);
-      } else {
-        return std::make_pair(node, true);
+    }
+    return std::make_pair(node, false);
+  }
+
+  if (node->isa<ValueNode>()) {
+    return FindParameterByValueNode(node, func_graph);
+  }
+
+  CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!IsValueNode<Primitive>(cnode->input(0))) {
+    for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+      if (!FindParameter(cnode->input(index), func_graph).first) {
+        continue;
       }
+      return FindParameter(cnode->input(index), func_graph);
     }
+  }
+
+  if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
+    return std::make_pair(node, false);
+  }
+
+  if (IsParallelCareNode(cnode)) {
     return std::make_pair(nullptr, false);
-  } else {
-    CNodePtr cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (!IsValueNode<Primitive>(cnode->input(0))) {
-      for (size_t index = 0; index < cnode->inputs().size(); ++index) {
-        if (!FindParameter(cnode->input(index), func_graph).first) {
-          continue;
-        }
-        return FindParameter(cnode->input(index), func_graph);
-      }
-    } else {
-      if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
-        return std::make_pair(node, false);
-      }
-      if (IsParallelCareNode(cnode)) {
-        return std::make_pair(nullptr, false);
-      } else {
-        ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-        MS_EXCEPTION_IF_NULL(prim_anf_node);
-        for (size_t index = 0; index < cnode->inputs().size(); ++index) {
-          PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
-          MS_EXCEPTION_IF_NULL(prim);
-          if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) {
-            continue;
-          }
-          if (!FindParameter(cnode->input(index), func_graph).first) {
-            continue;
-          }
-          return FindParameter(cnode->input(index), func_graph);
-        }
-      }
+  }
+
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(prim_anf_node);
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(prim);
+    if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) {
+      continue;
+    }
+    if (!FindParameter(cnode->input(index), func_graph).first) {
+      continue;
    }
+    return FindParameter(cnode->input(index), func_graph);
  }
  return std::make_pair(nullptr, false);
 }
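
Note: the hunk above flattens FindParameter's if/else-if chain into guard clauses and moves the RefKey handling into the new FindParameterByValueNode helper; the behavior is unchanged. The shape is: classify the node, return early for the terminal cases (Parameter, ValueNode), then fall through to a single recursive walk over the CNode's inputs. The standalone sketch below mirrors that control flow with a hypothetical Node type; it is not MindSpore code (FindParam, NodePtr, and Kind are made up for illustration), and it caches the recursive result instead of calling the function twice the way the real code does.

    #include <iostream>
    #include <memory>
    #include <utility>
    #include <vector>

    struct Node;
    using NodePtr = std::shared_ptr<Node>;

    struct Node {
      enum class Kind { kParameter, kValue, kCompound };
      Kind kind;
      std::vector<NodePtr> inputs;  // only kCompound nodes have inputs
    };

    // Returns {found node, reached-through-value-node flag}, shaped like
    // FindParameter's std::pair<AnfNodePtr, bool> result.
    std::pair<NodePtr, bool> FindParam(const NodePtr &node) {
      if (node == nullptr) {
        return {nullptr, false};  // guard clause: nothing to search
      }
      if (node->kind == Node::Kind::kParameter) {
        return {node, false};  // guard clause: terminal hit
      }
      if (node->kind == Node::Kind::kValue) {
        return {nullptr, true};  // the real code delegates this case to FindParameterByValueNode
      }
      // General case: one recursive walk over the inputs, first hit wins.
      for (const auto &input : node->inputs) {
        auto result = FindParam(input);
        if (result.first == nullptr) {
          continue;
        }
        return result;
      }
      return {nullptr, false};
    }

    int main() {
      auto param = std::make_shared<Node>();
      param->kind = Node::Kind::kParameter;
      auto value = std::make_shared<Node>();
      value->kind = Node::Kind::kValue;
      auto root = std::make_shared<Node>();
      root->kind = Node::Kind::kCompound;
      root->inputs = {value, param};
      std::cout << std::boolalpha << (FindParam(root).first == param) << "\n";  // prints true
      return 0;
    }
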
@@ -1101,6 +1108,25 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
   MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
 }
 
+static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
+  if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
+    MS_LOG(INFO) << "Input is ValueList, skip it.";
+    return false;
+  }
+
+  if ((node->inputs().size() == 2) &&
+      (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
+    MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has been handled by the make_tuple node";
+    return false;
+  }
+
+  if (mirror_ops.size() != node_size - 1) {
+    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
+                      << node_size - 1;
+  }
+  return true;
+}
+
 void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
@@ -1113,21 +1139,11 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
       node_size--;
     }
   }
 
-  if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
-    MS_LOG(INFO) << "Input is ValueList, skip it.";
-    return;
-  }
-  if ((node->inputs().size() == 2) &&
-      (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
-    MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
+  if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) {
     return;
   }
 
-  if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
-                      << node_size - 1;
-  }
   for (size_t index = 1; index < node_size; ++index) {
     OperatorVector backward_op = mirror_ops[index - 1];
     if (backward_op.empty()) {
@@ -1181,15 +1197,15 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
         // pipeline mirror would not be set, which should be supported later
         AddCommOpFusionType(comm_op, param_node_pair.first);
       }
-    } else {
-      for (auto &op : backward_op) {
-        AnfNodePtr pre_node = node->input(index);
-        InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
-        auto comm_op = node->input(index)->cast<CNodePtr>();
-        // add fusion flag
-        // pipeline mirror would not be set, which should be supported later
-        AddCommOpFusionType(comm_op, param_node_pair.first);
-      }
+      continue;
+    }
+    for (auto &op : backward_op) {
+      AnfNodePtr pre_node = node->input(index);
+      InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
+      auto comm_op = node->input(index)->cast<CNodePtr>();
+      // add fusion flag
+      // pipeline mirror would not be set, which should be supported later
+      AddCommOpFusionType(comm_op, param_node_pair.first);
     }
   }
 }
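
Note: CheckInsertMirrorOps centralizes the preconditions that InsertMirrorOps used to check inline. The size invariant it enforces comes from the CNode layout: input 0 is the primitive value node, so a node with node_size inputs has node_size - 1 data inputs, and mirror_ops must supply exactly one OperatorVector per data input; the insertion loop then pairs input index with mirror_ops[index - 1]. A toy illustration of that index arithmetic, with plain std::vector and std::string standing in for the MindSpore types:

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // CNode-style layout: slot 0 holds the primitive, slots 1..N-1 hold data inputs.
      std::vector<std::string> inputs = {"prim::MatMul", "x", "w"};
      std::vector<std::string> mirror_ops = {"mirror_for_x", "mirror_for_w"};
      size_t node_size = inputs.size();
      // The check that otherwise raises "Mirrorops's size is wrong!".
      assert(mirror_ops.size() == node_size - 1);
      for (size_t index = 1; index < node_size; ++index) {
        // Data input inputs[index] is paired with mirror_ops[index - 1].
        assert(!mirror_ops[index - 1].empty());
      }
      return 0;
    }
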
@@ -1849,13 +1865,29 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
   strategyPtr->ResetInputs(strategys);
 }
 
+static bool CheckExtractInformation(const CNodePtr &cnode) {
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
+    return false;
+  }
+
+  if (!IsParallelCareNode(cnode)) {
+    return false;
+  }
+  return true;
+}
+
 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
   // load strategy map from checkpoint
   StrategyMap stra_map;
-  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
-    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-    }
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
+      (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
+    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
   }
   vector<std::string> last_forward_node_ids;
   if (!is_training) {
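
Note: CheckExtractInformation plays the same role for ExtractInformation: every skip condition (null or non-primitive cnode, the MAKE_TUPLE / MAKE_LIST / RECEIVE primitives, and nodes rejected by IsParallelCareNode) is folded into one predicate, so the traversal in the next hunk reads filter-then-process. A minimal sketch of that shape, with a hypothetical ShouldProcess predicate over plain strings instead of CNodes:

    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    static bool ShouldProcess(const std::string &op_name) {
      static const std::set<std::string> kSkip = {"MakeTuple", "make_list", "Receive"};
      if (op_name.empty()) {
        return false;  // mirrors the cnode == nullptr / non-Primitive guard
      }
      if (kSkip.count(op_name) != 0) {
        return false;  // mirrors the MAKE_TUPLE / MAKE_LIST / RECEIVE guard
      }
      return true;     // mirrors IsParallelCareNode accepting the node
    }

    int main() {
      std::vector<std::string> ops = {"MatMul", "MakeTuple", "ReLU", ""};
      for (const auto &op : ops) {
        if (!ShouldProcess(op)) {
          continue;  // one guard per node, exactly like the rewritten loop
        }
        std::cout << "extract information: " << op << "\n";  // prints MatMul, ReLU
      }
      return 0;
    }
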
@@ -1865,76 +1897,71 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
 
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    if (!CheckExtractInformation(cnode)) {
       continue;
     }
+
     SetVirtualDatasetStrategy(cnode);
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-    if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
-      continue;
-    }
+
     auto attrs = prim->attrs();
     MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
-    if (IsParallelCareNode(cnode)) {
-      std::vector<Shapes> shape_list = ExtractShape(cnode);
-      if (shape_list.empty()) {
-        MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
-      }
-      OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
-      if (operator_ == nullptr) {
-        MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed";
-      }
-      auto &inputs = cnode->inputs();
-      std::vector<ValuePtr> input_value;
-      for (size_t index = 1; index < inputs.size(); ++index) {
-        if (inputs[index]->isa<ValueNode>()) {
-          input_value.push_back(GetValueNode(inputs[index]));
-        } else {
-          input_value.emplace_back(nullptr);
-        }
-      }
-      StrategyPtr strategyPtr = nullptr;
-      (*operator_).set_input_value(input_value);
-      (*operator_).set_outputs_dtype(cnode->Type());
-      (*operator_).set_cnode(cnode);
-      if (prim->name() == RESHAPE) {
-        cnode->set_user_data<OperatorInfo>(operator_);
+
+    std::vector<Shapes> shape_list = ExtractShape(cnode);
+    if (shape_list.empty()) {
+      MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
+    }
+    OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
+    MS_EXCEPTION_IF_NULL(operator_);
+
+    auto &inputs = cnode->inputs();
+    std::vector<ValuePtr> input_value;
+    for (size_t index = 1; index < inputs.size(); ++index) {
+      if (inputs[index]->isa<ValueNode>()) {
+        input_value.push_back(GetValueNode(inputs[index]));
        continue;
      }
-      // load strategy checkpoint
-      // key of strategy map
-      std::string strategy_key_name = "";
-      auto param_names = NodeParameterName(cnode);
-      if (!param_names.empty()) {
-        strategy_key_name = prim->name() + "_" + param_names[0].first;
-      }
-      bool load_strategy_from_ckpt =
-        StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
-      bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
-                           last_forward_node_ids.end();
-      bool full_batch = ParallelContext::GetInstance()->full_batch();
-      if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
-        MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
-                     << " is empty, using batch parallel";
-        strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
-      } else if (StrategyFound(attrs)) {
-        strategyPtr = ExtractStrategy(attrs);
-      } else {
-        strategyPtr = stra_map[strategy_key_name];
-      }
-      if (strategyPtr != nullptr) {
-        if (is_last_nodes && full_batch) {
-          SetLastNodeStrategy(strategyPtr);
-        }
-        if (operator_->Init(strategyPtr) == FAILED) {
-          MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
-        }
-        cnode->set_user_data<OperatorInfo>(operator_);
-      } else {
-        MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
-      }
+      input_value.emplace_back(nullptr);
+    }
+    StrategyPtr strategyPtr = nullptr;
+    (*operator_).set_input_value(input_value);
+    (*operator_).set_outputs_dtype(cnode->Type());
+    (*operator_).set_cnode(cnode);
+    if (prim->name() == RESHAPE) {
+      cnode->set_user_data<OperatorInfo>(operator_);
+      continue;
+    }
+    // load strategy checkpoint
+    // key of strategy map
+    std::string strategy_key_name = "";
+    auto param_names = NodeParameterName(cnode);
+    if (!param_names.empty()) {
+      strategy_key_name = prim->name() + "_" + param_names[0].first;
+    }
+    bool load_strategy_from_ckpt =
+      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
+    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
+                         last_forward_node_ids.end();
+    bool full_batch = ParallelContext::GetInstance()->full_batch();
+    if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
+      MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
+                   << " is empty, using batch parallel";
+      strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
+    } else if (StrategyFound(attrs)) {
+      strategyPtr = ExtractStrategy(attrs);
+    } else {
+      strategyPtr = stra_map[strategy_key_name];
+    }
+
+    MS_EXCEPTION_IF_NULL(strategyPtr);
+    if (is_last_nodes && full_batch) {
+      SetLastNodeStrategy(strategyPtr);
+    }
+    if (operator_->Init(strategyPtr) == FAILED) {
+      MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
    }
+    cnode->set_user_data<OperatorInfo>(operator_);
  }
 }
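
Note: the rewritten loop keeps the original strategy-selection precedence and only replaces the trailing null check with MS_EXCEPTION_IF_NULL: a last forward node in eval/predict without full_batch, or a node with no strategy in either the primitive's attrs or the checkpoint map, gets the generated batch-parallel strategy; otherwise a strategy found in attrs wins over one loaded from checkpoint. A condensed sketch of that decision with toy types (ChooseStrategy and the string labels are made up for illustration):

    #include <iostream>
    #include <optional>
    #include <string>

    std::string ChooseStrategy(bool is_last_node, bool full_batch,
                               const std::optional<std::string> &attr_strategy,
                               const std::optional<std::string> &ckpt_strategy) {
      // Mirrors: (is_last_nodes && !full_batch) || (!StrategyFound && !load_strategy_from_ckpt)
      if ((is_last_node && !full_batch) || (!attr_strategy && !ckpt_strategy)) {
        return "batch-parallel default";
      }
      if (attr_strategy) {
        return *attr_strategy;  // strategy attached to the primitive's attrs wins
      }
      return *ckpt_strategy;    // otherwise the strategy loaded from checkpoint
    }

    int main() {
      std::cout << ChooseStrategy(false, true, std::nullopt, std::nullopt) << "\n";
      std::cout << ChooseStrategy(false, true, std::string("attrs((4,1))"), std::nullopt) << "\n";
      std::cout << ChooseStrategy(false, true, std::nullopt, std::string("ckpt((2,2))")) << "\n";
      return 0;
    }
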
@@ -1994,9 +2021,9 @@ std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, si
   MS_EXCEPTION_IF_NULL(cnode);
   OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
   MS_EXCEPTION_IF_NULL(distribute_operator);
-  if (distribute_operator->outputs_tensor_info().size() < output_index) {
-    MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size()
-                      << ", must be less than output_index " << output_index;
+  if (distribute_operator->outputs_tensor_info().size() <= output_index) {
+    MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->outputs_tensor_info().size()
+                      << ", must be greater than output_index " << output_index;
   }
   TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
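
Note: the final hunk is a boundary fix. Valid indices into outputs_tensor_info() are 0 .. size() - 1, so the lookup at output_index must be rejected whenever size() <= output_index. The old comparison size() < output_index accepted the boundary case size() == output_index and then indexed one past the end. A self-contained demonstration of the two comparisons:

    #include <iostream>
    #include <vector>

    bool OldCheckRejects(size_t size, size_t index) { return size < index; }   // buggy
    bool NewCheckRejects(size_t size, size_t index) { return size <= index; }  // fixed

    int main() {
      std::vector<int> outputs_tensor_info = {42};  // size() == 1
      size_t output_index = 1;                      // one past the end
      std::cout << std::boolalpha
                << OldCheckRejects(outputs_tensor_info.size(), output_index) << "\n"   // false: out-of-range access slips through
                << NewCheckRejects(outputs_tensor_info.size(), output_index) << "\n";  // true: rejected before indexing
      return 0;
    }
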