!2939 [CleanCode] codedex 20200708

Merge pull request !2939 from Chong/zc
pull/2939/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 1cbf6c5d0b

@ -168,12 +168,11 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_
const size_t iter_ops, std::vector<int32_t> s) { const size_t iter_ops, std::vector<int32_t> s) {
std::vector<std::vector<int32_t>> strategies; std::vector<std::vector<int32_t>> strategies;
int32_t axis = 0;
auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2)); auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
if (axis_input < 0) { if (axis_input < 0) {
axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
} }
axis = axis_input; int32_t axis = axis_input;
if (axis >= SizeToInt(s.size())) { if (axis >= SizeToInt(s.size())) {
MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
} }

@ -20,7 +20,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <set>
#include "ir/value.h" #include "ir/value.h"
#include "parallel/auto_parallel/rec_core/rec_graph.h" #include "parallel/auto_parallel/rec_core/rec_graph.h"
@ -215,23 +214,16 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::shared_ptr<std::vector<size_t>> &index_list) { const std::shared_ptr<std::vector<size_t>> &index_list) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
static const std::set<OperatorType> elementwise_type = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
auto type = graph->nodes[node_index].apply.op_type; auto type = graph->nodes[node_index].apply.op_type;
if (elementwise_type.find(type) != elementwise_type.end()) { if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) {
Eliminate_Aux(node_index, graph, eli_list); Eliminate_Aux(node_index, graph, eli_list);
} }
} }
index_list->reserve(graph->nodes.size()); index_list->reserve(graph->nodes.size());
for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
index_list->push_back(i); index_list->push_back(i);
} }
for (size_t i = 0; i < (size_t)eli_list->size(); i++) { for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
if (eli_list->at(i)[0] >= index_list->size()) { if (eli_list->at(i)[0] >= index_list->size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
@ -241,13 +233,11 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
index_list->at(j)--; index_list->at(j)--;
} }
} }
std::shared_ptr<Graph> new_graph(new Graph); std::shared_ptr<Graph> new_graph(new Graph);
for (size_t i = 0; i < graph->nodes.size(); i++) { for (size_t i = 0; i < graph->nodes.size(); i++) {
if (index_list->at(i) > SIZE_MAX / 2) { if (index_list->at(i) > SIZE_MAX / 2) {
continue; continue;
} }
new_graph->nodes.push_back(graph->nodes[i]); new_graph->nodes.push_back(graph->nodes[i]);
auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
for (size_t j = node_in->size(); j > 0; j--) { for (size_t j = node_in->size(); j > 0; j--) {

@ -22,12 +22,19 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <set>
#include "parallel/auto_parallel/rec_core/rec_graph.h" #include "parallel/auto_parallel/rec_core/rec_graph.h"
#include "parallel/ops_info/operator_info.h" #include "parallel/ops_info/operator_info.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
static const std::set<OperatorType> ElementWiseOpType = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
const std::map<std::string, OperatorType> DictOpType{ const std::map<std::string, OperatorType> DictOpType{
{MATMUL, OperatorType::kRecMatMul}, {MATMUL, OperatorType::kRecMatMul},
{CONV2D, OperatorType::kRecConvolution}, {CONV2D, OperatorType::kRecConvolution},

Loading…
Cancel
Save