|
|
@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodeP
|
|
|
|
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
|
|
|
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
|
|
|
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
|
|
|
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// add fusion flag
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(allgather);
|
|
|
|
MS_EXCEPTION_IF_NULL(allgather);
|
|
|
|
|
|
|
|
// add fusion flag
|
|
|
|
AddCommOpFusionType(allgather, parameter);
|
|
|
|
AddCommOpFusionType(allgather, parameter);
|
|
|
|
|
|
|
|
// add gradients mean
|
|
|
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
|
|
|
|
|
|
|
|
auto attrs = prim->attrs();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
|
|
|
|
|
|
|
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
|
|
|
|
|
|
|
|
attrs["mean_flag"] = MakeValue<bool>(mean_flag);
|
|
|
|
|
|
|
|
prim->SetAttrs(attrs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
|
|
|
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
|
|
|