From bee57fda66f4662a75a4f6efed4cde8e8b92d4c3 Mon Sep 17 00:00:00 2001
From: hongxing
Date: Mon, 8 Jun 2020 17:04:32 +0200
Subject: [PATCH] support GatherV2 + Depend

---
 .../rec_core/rec_generate_strategy.cc          | 64 ++++++++++++++-----
 .../rec_core/rec_generate_strategy.h           | 15 +++--
 .../auto_parallel/rec_core/rec_graph.h         |  3 +-
 .../auto_parallel/rec_core/rec_parse_graph.cc  |  2 +-
 .../auto_parallel/rec_core/rec_parse_graph.h   |  7 +-
 .../auto_parallel/rec_core/rec_partition.cc    | 32 ++++------
 mindspore/ccsrc/parallel/step_auto_parallel.cc | 36 +++++++++--
 mindspore/ccsrc/parallel/step_auto_parallel.h  |  2 +
 8 files changed, 109 insertions(+), 52 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index 6cad20b568..4bc183b1a2 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -135,24 +135,51 @@ std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &gra
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s) {
+std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
+                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                   const size_t iter_graph, const size_t iter_ops) {
+  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
+  for (size_t i = 1; i < strategies.size(); i++) {
+    strategies[i][0] = strategies[0][1];
+  }
+  strategies[1][0] = 1;
+  return strategies;
+}
+
+std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
+                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_graph, const size_t iter_ops) {
+  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
+  return strategies;
+}
+
+std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
   std::vector<std::vector<int32_t>> strategies;
-  strategies.push_back(s);
+  strategies.push_back(*s);
   std::vector<int32_t> s_biasadd;
-  s_biasadd.push_back(s[1]);
+  s_biasadd.push_back(s->at(1));
   strategies.push_back(s_biasadd);
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s) {
+std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s) {
   std::vector<std::vector<int32_t>> strategies;
   std::vector<int32_t> s_empty = {};
-  strategies.push_back(s);
+  strategies.push_back(*s);
   strategies.push_back(s_empty);
   strategies.push_back(s_empty);
   return strategies;
 }
 
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
+  std::vector<std::vector<int32_t>> strategies;
+  strategies.push_back(*s);
+  return strategies;
+}
+
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops) {
@@ -270,6 +297,12 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == PRELU) {
     return PreparePReLU(graph, ops, iter_graph, iter_ops);
+  } else if (type == BATCH_NORM) {
+    return PrepareBatchNorm(graph, ops, iter_graph, iter_ops);
+  } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return PrepareSoftmaxWithLogits(graph, ops, iter_graph, iter_ops);
+  } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
@@ -336,7 +369,7 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                           const size_t incoming_op_index) {
   std::vector<int32_t> s;
-  if (ops[incoming_op_index]->type() == RESHAPE) {
+  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) {
     return s;
   }
   auto strategy = ops[incoming_op_index]->selected_strategy();
@@ -456,11 +489,6 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
   return s;
 }
 
-std::vector<int32_t> ModifyStrategyIfSoftmaxIncoming(std::vector<int32_t> s) {
-  s.pop_back();
-  return s;
-}
-
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t incoming_op_index) {
   std::vector<int32_t> s;
@@ -474,9 +502,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
         ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
       s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
     }
-    if (ops[incoming_op_index]->type() == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
-      s = ModifyStrategyIfSoftmaxIncoming(s);
-    }
   }
   return s;
 }
@@ -496,11 +521,15 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
     return stra;
   }
 
+  auto s_ptr = std::make_shared<std::vector<int32_t>>(basic_stra);
   if (ops[iter_ops]->type() == BIAS_ADD) {
-    return PrepareBiasAdd(basic_stra);
+    return PrepareBiasAdd(s_ptr);
   }
   if (ops[iter_ops]->type() == ONEHOT) {
-    return PrepareOneHot(basic_stra);
+    return PrepareOneHot(s_ptr);
+  }
+  if (ops[iter_ops]->type() == GATHERV2) {
+    return PrepareGatherV2(s_ptr);
   }
 
   for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
@@ -599,7 +628,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
                                                        const size_t iter_ops) {
   std::vector<int32_t> s;
   if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
-      ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE) {
+      ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
+      ops[iter_ops]->type() == GATHERV2) {
     return s;
   }
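Note: the following standalone sketch (not part of the patch) illustrates what the hunks above make GatherV2 do during strategy generation: the basic strategy propagated from a neighboring operator is forwarded unchanged as the strategy of input 0, and no entry is emitted for the indices input. The main() driver and the 4-device strategy values are illustrative assumptions.

    // Sketch only: PrepareGatherV2 forwards the basic strategy as-is.
    #include <cstdint>
    #include <iostream>
    #include <memory>
    #include <vector>

    std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
      std::vector<std::vector<int32_t>> strategies;
      strategies.push_back(*s);  // strategy for input 0 only
      return strategies;
    }

    int main() {
      // Hypothetical basic strategy: split the first dimension over 4 devices.
      auto s_ptr = std::make_shared<std::vector<int32_t>>(std::vector<int32_t>{4, 1});
      for (int32_t dim : PrepareGatherV2(s_ptr)[0]) std::cout << dim << ' ';
      std::cout << '\n';  // prints "4 1"
      return 0;
    }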
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index acda2b8452..2b76c59728 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -37,8 +37,15 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
 std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s);
-std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
+std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
+                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                   const size_t iter_graph, const size_t iter_ops);
+std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
+                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_graph, const size_t iter_ops);
+std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops);
@@ -64,11 +71,11 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t incoming_op_index, std::vector<int32_t> s);
-std::vector<int32_t> ModifyStrategyIfSoftmaxIncoming(std::vector<int32_t> s);
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t incoming_op_index);
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                                 const size_t iter_ops, std::vector<int32_t> s);
+                                                                 const size_t iter_ops,
+                                                                 std::vector<int32_t> basic_stra);
 void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
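Note: the header change switches PrepareBiasAdd/PrepareOneHot/PrepareGatherV2 from taking the basic strategy by value to taking it through a shared_ptr. A minimal call-site fragment, mirroring what GenerateStrategiesFromStrategy now does in the .cc hunk above (the concrete strategy values are illustrative, and PrepareBiasAdd is the helper defined in this patch):

    // Fragment: wrap the basic strategy once, then hand the same pointer to
    // whichever Prepare* helper matches the operator type.
    std::vector<int32_t> basic_stra = {8, 1};
    auto s_ptr = std::make_shared<std::vector<int32_t>>(basic_stra);
    auto strategies = PrepareBiasAdd(s_ptr);  // yields {{8, 1}, {1}}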
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
index 9fcb6e5f69..879e22cb1f 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
@@ -48,7 +48,8 @@ enum OperatorType {
   kRecSqueeze,
   kRecCast,
   kRecReduce,
-  kRecPReLU
+  kRecPReLU,
+  kRecGatherV2
 };
 
 enum InfoType { kApplication, kConstant };
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index add0f5553e..979f987225 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -199,7 +199,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     OperatorType::kRecOneHot, OperatorType::kRecReLU,      OperatorType::kRecLog,     OperatorType::kRecExp,
     OperatorType::kRecAdd,    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
     OperatorType::kRecMul,    OperatorType::kRecDiv,       OperatorType::kRecSqueeze, OperatorType::kRecReduce,
-    OperatorType::kRecCast,   OperatorType::kRecReshape};
+    OperatorType::kRecCast,   OperatorType::kRecReshape,   OperatorType::kRecGatherV2};
   for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
     auto type = graph->nodes[node_index].apply.op_type;
     if (type_list.find(type) != type_list.end()) {
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index 34df09cb99..1b51e4d9b0 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -46,6 +46,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {REDUCE_MAX, OperatorType::kRecReduce},
   {REDUCE_MIN, OperatorType::kRecReduce},
   {REDUCE_MEAN, OperatorType::kRecReduce},
+  {GATHERV2, OperatorType::kRecGatherV2},
 
   {RELU, OperatorType::kRecReLU},
   {"ReLU6", OperatorType::kRecReLU},
@@ -63,9 +64,9 @@ const std::map<std::string, OperatorType> DictOpType{
   {MUL, OperatorType::kRecElmWiseOp},
   {DIV, OperatorType::kRecElmWiseOp},
   {REAL_DIV, OperatorType::kRecElmWiseOp},
-  {SOFTMAX, OperatorType::kRecElmWiseOp},
-  {LOG_SOFTMAX, OperatorType::kRecElmWiseOp},
-  {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecElmWiseOp},
+  {SOFTMAX, OperatorType::kRecSoftmax},
+  {LOG_SOFTMAX, OperatorType::kRecSoftmax},
+  {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmax},
   {SQRT, OperatorType::kRecElmWiseOp},
   {NEG, OperatorType::kRecElmWiseOp},
   {POW, OperatorType::kRecElmWiseOp},
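Note: a self-contained sketch (not MindSpore code) of how the two parser-side tables above interact. DictOpType classifies a primitive by name; EliminateGraph's type_list marks op types whose nodes are folded out of the recursive-partition graph and receive their strategies afterwards from neighboring operators, which is where the new PrepareGatherV2 path comes in. The enum excerpt mirrors the diff; the lookup driver is illustrative only.

    #include <iostream>
    #include <map>
    #include <set>
    #include <string>

    enum OperatorType { kRecElmWiseOp, kRecSoftmax, kRecGatherV2, kRecUnkownType };

    int main() {
      // Excerpts of DictOpType and EliminateGraph's type_list from the diff.
      const std::map<std::string, OperatorType> dict_op_type{
          {"GatherV2", kRecGatherV2}, {"Softmax", kRecSoftmax}};
      const std::set<OperatorType> type_list{kRecGatherV2};

      auto it = dict_op_type.find("GatherV2");
      OperatorType type = (it == dict_op_type.end()) ? kRecUnkownType : it->second;
      std::cout << (type_list.count(type) ? "eliminated, strategy filled in later"
                                          : "partitioned by cost model")
                << '\n';
      return 0;
    }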
a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -53,9 +53,8 @@ double GetWeights(const Graph::NodeType &node) {
     auto cost_ptr = std::make_shared<CostPooling>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecReLU || op.op_type == OperatorType::kRecSoftmax ||
-             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
-    // For Activation and Softmax
+  } else if (op.op_type == OperatorType::kRecReLU) {
+    // For Activation
     auto cost_ptr = std::make_shared<CostCommon>();
 
     return cost_ptr->GetMinCostIn();
@@ -69,11 +68,6 @@ double GetWeights(const Graph::NodeType &node) {
     auto cost_ptr = std::make_shared<CostBiasAdd>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecBatchNorm) {
-    // For BatchNorm
-    auto cost_ptr = std::make_shared<CostBatchNorm>();
-
-    return cost_ptr->GetMinCostIn(op);
   } else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
              op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
              op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
@@ -83,8 +77,10 @@ double GetWeights(const Graph::NodeType &node) {
     auto cost_ptr = std::make_shared<CostCommon>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU) {
-    // For unknown type
+  } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU ||
+             op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+    // For unprocessed type
     return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
@@ -147,9 +143,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostPooling>();
 
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecReLU || node.apply.op_type == OperatorType::kRecSoftmax ||
-             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
-    // For Softmax & Activation
+  } else if (node.apply.op_type == OperatorType::kRecReLU) {
+    // For Activation
     auto cost_ptr = std::make_shared<CostCommon>();
 
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
@@ -162,11 +157,6 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     // For BiasAdd
     auto cost_ptr = std::make_shared<CostBiasAdd>();
 
-    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecBatchNorm) {
-    // For BatchNorm
-    auto cost_ptr = std::make_shared<CostBatchNorm>();
-
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
   } else if (node.apply.op_type == OperatorType::kRecOneHot || node.apply.op_type == OperatorType::kRecLog ||
              node.apply.op_type == OperatorType::kRecExp || node.apply.op_type == OperatorType::kRecAdd ||
@@ -177,8 +167,10 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostCommon>();
 
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU) {
-    // For unknown type
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU ||
+             node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecSoftmax ||
+             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+    // For unprocessed type
     StrategyRec default_strategy;
     return default_strategy;
   } else {
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index fc6ca25e4a..abd13a8a03 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -465,9 +465,11 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     if (!IsAutoParallelCareNode(cnode)) {
       // Needed by rec_parser
-      PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-      if (prim->name() == TUPLE_GETITEM) {
-        entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
+      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
+        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
+        if (prev_cnode != nullptr) {
+          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
+        }
       }
       continue;
     }
@@ -528,9 +530,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     if (!IsAutoParallelCareNode(cnode)) {
       // Needed by rec_parser
-      PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-      if (prim->name() == TUPLE_GETITEM) {
-        entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
+      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
+        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
+        if (prev_cnode != nullptr) {
+          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
+        }
       }
       continue;
     }
@@ -1155,6 +1159,26 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
   return input_tensor_names;
 }
 
+CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
+    auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
+    if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
+      return nullptr;
+    }
+    auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
+    while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
+      prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
+      if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
+        return nullptr;
+      }
+      prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
+    }
+    return prev_cnode;
+  }
+  return nullptr;
+}
+
 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
   if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.h b/mindspore/ccsrc/parallel/step_auto_parallel.h
index fff9dfa4c3..c923e5770f 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.h
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.h
@@ -57,6 +57,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
 
 std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
                                                           std::vector<std::vector<std::string>> input_tensor_names);
+
+CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_STEP_AUTO_PARALLEL_H_
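Note: the net effect of the rec_partition.cc changes is that BatchNorm, Softmax, and SparseSoftmaxCrossEntropyWithLogits no longer go through a cost model: GetWeights reports zero weight for them and PartitionNode hands back a default StrategyRec, leaving their strategies to the Prepare* helpers added earlier. A small behavioral sketch under those assumptions (types and values are stand-ins, not the real cost model):

    #include <iostream>

    enum OperatorType { kRecMatMul, kRecBatchNorm, kRecSoftmax };
    struct StrategyRec {};  // stand-in for the real strategy record

    double GetWeightsSketch(OperatorType type) {
      if (type == kRecMatMul) return 1.0;  // stand-in cost-model estimate
      return 0.0;                          // unprocessed types carry no weight
    }

    StrategyRec PartitionNodeSketch(OperatorType type) {
      // Unprocessed types are never cut; they get a default strategy.
      return (type == kRecMatMul) ? StrategyRec{/* cost-model result */} : StrategyRec{};
    }

    int main() {
      std::cout << GetWeightsSketch(kRecBatchNorm) << '\n';  // prints 0
      (void)PartitionNodeSketch(kRecSoftmax);
      return 0;
    }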
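Note: a self-contained sketch (not MindSpore code) of the traversal that GetInternalOperatorInfo adds above. Starting from a tuple_getitem or Depend node, it follows the first data input until the producer is some other operator, so rec_parser can link the consumer to the node that really produces the tensor. The Node struct is a toy stand-in for CNode, and the string values "tuple_getitem"/"Depend" are assumed to match the TUPLE_GETITEM/DEPEND constants.

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct Node {
      std::string prim;                           // primitive name, e.g. "Depend"
      std::vector<std::shared_ptr<Node>> inputs;  // inputs[0] = first data input
    };

    bool IsWrapper(const std::shared_ptr<Node> &n) {
      return n != nullptr && (n->prim == "tuple_getitem" || n->prim == "Depend");
    }

    std::shared_ptr<Node> GetInternalOperator(const std::shared_ptr<Node> &node) {
      if (!IsWrapper(node)) return nullptr;  // mirrors the early nullptr return
      auto prev = node->inputs.empty() ? std::shared_ptr<Node>() : node->inputs[0];
      while (IsWrapper(prev)) {
        prev = prev->inputs.empty() ? std::shared_ptr<Node>() : prev->inputs[0];
      }
      return prev;  // first producer that is neither tuple_getitem nor Depend
    }

    int main() {
      auto matmul = std::make_shared<Node>(Node{"MatMul", {}});
      auto depend = std::make_shared<Node>(Node{"Depend", {matmul}});
      auto getitem = std::make_shared<Node>(Node{"tuple_getitem", {depend}});
      std::cout << GetInternalOperator(getitem)->prim << '\n';  // prints "MatMul"
      return 0;
    }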