!13104 insert parallel optimizer once

From: @gong_zi_yan
Reviewed-by: @stsuteng, @kisnwang
Signed-off-by: @stsuteng
pull/13104/MERGE
Committed by mindspore-ci-bot via Gitee
commit 85e5fed534

@@ -1571,7 +1571,7 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
 }
 static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
-                              const AnfNodePtr &node) {
+                              const AnfNodePtr &node, const std::string &op_name) {
   MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = res.first->cast<CNodePtr>();
@@ -1579,10 +1579,9 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
   MS_EXCEPTION_IF_NULL(graph);
   auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   MS_EXCEPTION_IF_NULL(cnode_prim);
-  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
   Operator op;
   CNodePtr allgather;
-  if (grad_accumulation_step > 1) {
+  if (op_name == MINI_STEP_ALL_GATHER) {
     op = CreateMiniStepAllGatherOp(group);
     auto param_name = node->cast<ParameterPtr>()->name();
     if (cnode_prim->name() == CAST) {
@@ -1613,21 +1612,41 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   }
   FuncGraphManagerPtr manager = root->manager();
   MS_EXCEPTION_IF_NULL(manager);
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  std::string op_name;
+  if (grad_accumulation_step > 1) {
+    op_name = MINI_STEP_ALL_GATHER;
+  } else {
+    op_name = ALL_GATHER;
+  }
   auto param_sub_set = manager->node_users()[parameter];
+  bool insert_flag = false;
   for (auto &param_pair : param_sub_set) {
     auto cnode = param_pair.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->in_forward_flag()) {
       OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
       if (distribute_operator == nullptr) {
-        MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+        MS_LOG(WARNING) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
       } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
         MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
                           << distribute_operator->inputs_tensor_info().size();
       }
-      // insert allgather operator between shard parameter and cnode
-      InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
-      MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
+      if (insert_flag) {
+        auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph());
+        if (next_cnode.first) {
+          manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
+          MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
+                       << GetPrimName(cnode);
+          continue;
+        }
+      } else {
+        // insert allgather operator between shard parameter and cnode
+        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
+        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
+                     << GetPrimName(cnode);
+        insert_flag = true;
+      }
     }
   }
 }
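
For readers outside the MindSpore codebase, the pattern the diff implements is: create the collective for the first user of a shard parameter, then point every subsequent user at the node that already exists. Below is a minimal standalone C++ sketch of that idea; Node, MakeAllGather, and FindExisting are hypothetical stand-ins for the real IR types and helpers (FuncGraphManager, FindCNode, SetEdge), not MindSpore APIs.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR node; not MindSpore code.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Create an AllGather-like node that consumes `param`.
NodePtr MakeAllGather(const NodePtr &param, const std::string &op_name) {
  return std::make_shared<Node>(Node{op_name, {param}});
}

// Look up an already-inserted node by name (the FindCNode analogue).
NodePtr FindExisting(const std::vector<NodePtr> &inserted, const std::string &op_name) {
  for (const auto &n : inserted) {
    if (n->name == op_name) return n;
  }
  return nullptr;
}

int main() {
  auto param = std::make_shared<Node>(Node{"weight", {}});
  // Two forward users that both read the same shard parameter.
  std::vector<NodePtr> users = {std::make_shared<Node>(Node{"MatMul", {param}}),
                                std::make_shared<Node>(Node{"Cast", {param}})};
  std::vector<NodePtr> inserted;

  bool insert_flag = false;  // mirrors the flag this commit adds
  for (auto &user : users) {
    if (insert_flag) {
      // Reuse the AllGather created for an earlier user of this parameter.
      if (auto existing = FindExisting(inserted, "AllGather")) {
        user->inputs[0] = existing;  // the SetEdge analogue
        continue;
      }
    }
    // First user: insert the collective once and remember it.
    auto ag = MakeAllGather(param, "AllGather");
    inserted.push_back(ag);
    user->inputs[0] = ag;
    insert_flag = true;
  }

  // Both users now read the parameter through the single inserted AllGather.
  for (const auto &u : users) {
    std::cout << u->name << " <- " << u->inputs[0]->name << '\n';
  }
  return 0;
}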
