From 83e627dd5e363db5a9f2f6351c58e77d129418de Mon Sep 17 00:00:00 2001 From: sheng Date: Wed, 21 Oct 2020 19:09:33 +0200 Subject: [PATCH] Support Dropout stra; Add parser's input/output tensor num check --- .../auto_parallel/rec_core/rec_generate_strategy.cc | 2 +- .../parallel/auto_parallel/rec_core/rec_parse_graph.cc | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 37f2cad863..f9f177e9f5 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -416,7 +416,7 @@ Strategys PrepareStrategy(const std::shared_ptr &graph, const std::vector return PrepareMatMul(graph, ops, iter_graph, iter_ops); } else if (type == ONEHOT) { return PrepareOneHot(graph, ops, iter_graph, iter_ops); - } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset")) { + } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || (type == "Dropout")) { return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); } else { return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 3a52d0705d..41f790ecb4 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -53,6 +53,10 @@ Graph::NodeType MakeNewOperator(const std::vector> NewOp.apply.op_type = DictOpType.at(op_type); } + if (ops[iter_ops]->outputs_tensor_info().size() == 0) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty."; + } + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { NewOp.tensor_parm = MakeTensor( ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], @@ -74,6 +78,10 @@ Graph::NodeType MakeNewOperator(const std::vector> OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor) { + if (ops[iter_ops]->inputs_tensor_info().size() > MAX_INPUT_NUM) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor num exceeds limit."; + } + for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); iter_input_tensors++) { if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) {