diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc index e88868c772..87b8d15cca 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc @@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; } @@ -215,7 +215,8 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { } auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode); + if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; } if (!IsValueNode(dropout_gen_mask_cnode->input(0))) { @@ -232,45 +233,11 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { return prim; } -void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - - AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); - MS_EXCEPTION_IF_NULL(dropout_gen_mask); - if (!dropout_gen_mask->isa()) { - MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode."; - } - - auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; - } - - if (!IsValueNode(dropout_gen_mask_cnode->input(1))) { - MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple."; - } - - FuncGraphPtr func_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr."; - } - - ValuePtr new_shape = MakeValue(input_slice_shape); - AnfNodePtr val = NewValueNode(new_shape); - (void)manager->Replace(dropout_gen_mask_cnode->input(1), val); -} - // DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. -std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { - std::vector replace_ops; +Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); MS_EXCEPTION_IF_NULL(prim); @@ -293,20 +260,15 @@ std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) { MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1"; } - - Shape input_slice_shape = inputs_tensor_info_[0].slice_shape(); int32_t seed_0 = GetValue(attr[SEED0]); int32_t seed_1 = GetValue(attr[SEED1]); if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) { seed_0 = SEED_NUM; seed_1 = SEED_NUM; SEED_NUM++; - } else { - SetGenMaskShape(cnode, input_slice_shape); - MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape); - return replace_ops; } + Shape input_slice_shape = inputs_tensor_info_[0].slice_shape(); ValuePtr new_shape = MakeValue(input_slice_shape); Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0)); Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1)); @@ -316,8 +278,7 @@ std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)}; OperatorArgs args = std::make_pair(attrs, params); Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)}; - replace_ops.push_back(replace_op); - return replace_ops; + return replace_op; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h index c51a0a9513..c0d112f52d 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h @@ -41,7 +41,7 @@ class DropoutDoMaskInfo : public OperatorInfo { Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; - std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); + Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); protected: Status CheckStrategy(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 39dd2c96e0..4528ff8639 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -1876,15 +1876,11 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast(distribute_operator); MS_EXCEPTION_IF_NULL(dropout_do_mask); - std::vector replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); - if (replace_op.empty()) { - MS_LOG(DEBUG) << "No need to replace dropout_gen_mask"; - return; - } + Operator replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; } - ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); + ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); } void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {