From 66553ac3c5ada6a370964994716373cf426f11b1 Mon Sep 17 00:00:00 2001 From: hongxing Date: Wed, 8 Jul 2020 09:44:17 +0200 Subject: [PATCH] optimize code --- .../auto_parallel/rec_core/rec_generate_strategy.cc | 3 +-- .../auto_parallel/rec_core/rec_parse_graph.cc | 12 +----------- .../auto_parallel/rec_core/rec_parse_graph.h | 7 +++++++ 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 9de71231c0..828523fed1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -168,12 +168,11 @@ std::vector> PrepareGatherV2(const std::vector s) { std::vector> strategies; - int32_t axis = 0; auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); if (axis_input < 0) { axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); } - axis = axis_input; + int32_t axis = axis_input; if (axis >= SizeToInt(s.size())) { MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index c0412e9108..0e6a3411e3 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include "ir/value.h" #include "parallel/auto_parallel/rec_core/rec_graph.h" @@ -215,23 +214,16 @@ std::shared_ptr EliminateGraph(const std::shared_ptr &graph, const std::shared_ptr>> &eli_list, const std::shared_ptr> &index_list) { MS_EXCEPTION_IF_NULL(graph); - static const std::set elementwise_type = { - OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, - OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, - OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, - OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { auto type = graph->nodes[node_index].apply.op_type; - if (elementwise_type.find(type) != elementwise_type.end()) { + if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) { Eliminate_Aux(node_index, graph, eli_list); } } - index_list->reserve(graph->nodes.size()); for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { index_list->push_back(i); } - for (size_t i = 0; i < (size_t)eli_list->size(); i++) { if (eli_list->at(i)[0] >= index_list->size()) { MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; @@ -241,13 +233,11 @@ std::shared_ptr EliminateGraph(const std::shared_ptr &graph, index_list->at(j)--; } } - std::shared_ptr new_graph(new Graph); for (size_t i = 0; i < graph->nodes.size(); i++) { if (index_list->at(i) > SIZE_MAX / 2) { continue; } - new_graph->nodes.push_back(graph->nodes[i]); auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; for (size_t j = node_in->size(); j > 0; j--) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 53abefd1c8..c05c7d33b8 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -22,12 +22,19 @@ #include #include #include +#include #include "parallel/auto_parallel/rec_core/rec_graph.h" #include "parallel/ops_info/operator_info.h" namespace mindspore { namespace parallel { +static const std::set ElementWiseOpType = { + OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, + OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, + OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, + OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; + const std::map DictOpType{ {MATMUL, OperatorType::kRecMatMul}, {CONV2D, OperatorType::kRecConvolution},