From 4109308e3499ba9a00d68c006b1d1e2a5b68d2ea Mon Sep 17 00:00:00 2001 From: Ziyan Date: Wed, 10 Mar 2021 16:19:35 +0800 Subject: [PATCH] insert parallel optimizer once --- .../ccsrc/frontend/parallel/step_parallel.cc | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index fe011ddc4b..7e8b44400a 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1571,7 +1571,7 @@ std::pair FindSubGraph(const FuncGraphPtr &graph, const Anf } static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair &res, - const AnfNodePtr &node) { + const AnfNodePtr &node, const std::string &op_name) { MS_EXCEPTION_IF_NULL(res.first); MS_EXCEPTION_IF_NULL(node); auto cnode = res.first->cast(); @@ -1579,10 +1579,9 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group MS_EXCEPTION_IF_NULL(graph); auto cnode_prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(cnode_prim); - int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); Operator op; CNodePtr allgather; - if (grad_accumulation_step > 1) { + if (op_name == MINI_STEP_ALL_GATHER) { op = CreateMiniStepAllGatherOp(group); auto param_name = node->cast()->name(); if (cnode_prim->name() == CAST) { @@ -1613,21 +1612,41 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & } FuncGraphManagerPtr manager = root->manager(); MS_EXCEPTION_IF_NULL(manager); + int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + std::string op_name; + if (grad_accumulation_step > 1) { + op_name = MINI_STEP_ALL_GATHER; + } else { + op_name = ALL_GATHER; + } auto param_sub_set = manager->node_users()[parameter]; + bool insert_flag = false; for (auto &param_pair : param_sub_set) { auto cnode = 
param_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (cnode->in_forward_flag()) { OperatorInfoPtr distribute_operator = cnode->user_data(); if (distribute_operator == nullptr) { - MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; + MS_LOG(WARNING) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr"; } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) { MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " << distribute_operator->inputs_tensor_info().size(); } - // insert allgather operator between shard parameter and cnode - InsertAllGatherOp(root, opt_shard_group, param_pair, parameter); - MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString(); + if (insert_flag) { + auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph()); + if (next_cnode.first) { + manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second); + MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " + << GetPrimName(cnode); + continue; + } + } else { + // insert allgather operator between shard parameter and cnode + InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name); + MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " + << GetPrimName(cnode); + insert_flag = true; + } } } }