!10743 enable gradients mean in opt shard

From: @gong_zi_yan
Reviewed-by: @stsuteng, @yao_yf, @kisnwang
Signed-off-by: @stsuteng
pull/10743/MERGE
Committed-by: mindspore-ci-bot (via Gitee)
commit b07dd76246

@@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodeP
     InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
     allgather = cnode->input(res.second)->cast<CNodePtr>();
   }
-  // add fusion flag
   MS_EXCEPTION_IF_NULL(allgather);
+  // add fusion flag
   AddCommOpFusionType(allgather, parameter);
+  // add gradients mean
+  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
+  auto attrs = prim->attrs();
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
+  attrs["mean_flag"] = MakeValue<bool>(mean_flag);
+  prim->SetAttrs(attrs);
 }
 
 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
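
The C++ hunk above stamps a mean_flag attribute onto the inserted AllGather primitive, copying the global gradients_mean setting from the parallel context. For context, here is a minimal sketch of how that setting is enabled from Python; set_auto_parallel_context, gradients_mean, and enable_parallel_optimizer are the public MindSpore API, while the choice of parallel_mode is only an assumption for illustration:

    # Sketch: enable gradient averaging together with optimizer sharding.
    # "semi_auto_parallel" is an assumed mode, used here for illustration only.
    from mindspore import context

    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        gradients_mean=True,             # read back in C++ via ParallelContext::gradients_mean()
        enable_parallel_optimizer=True)  # the optimizer-shard path that inserts the AllGather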

@@ -134,6 +134,8 @@ def get_bprop_all_gather(self):
     rank = get_rank(self.group)
     dev_num = get_group_size(self.group)
     split = P.Split(output_num=dev_num)
+    mean_flag = self.get_attr_dict()["mean_flag"]
+    scale = 1/self.rank_size
 
     def bprop(x, out, dout):
         if fusion == 0:
@@ -141,6 +143,8 @@ def get_bprop_all_gather(self):
         else:
             grad = all_reduce(dout)
             dx = split(grad)[rank]
+        if mean_flag:
+            dx = F.tensor_mul(dx, scale)
         return (dx,)
 
     return bprop
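
For intuition on the Python hunk: the backward of AllGather sums the incoming gradients across the group (all_reduce) and keeps only this rank's slice (split), so when the loss is averaged across devices the summed gradient must be rescaled by 1/rank_size, which is exactly what the new mean_flag branch does. A NumPy sketch of the fusion != 0 path, with made-up values; this is illustrative only, not MindSpore code:

    # NumPy model of the bprop above: all_reduce(sum) + split + optional mean scaling.
    import numpy as np

    dev_num, rank, mean_flag = 4, 1, True
    # dout as seen on each of the dev_num ranks (made-up gradient values)
    dout_per_rank = [np.arange(8, dtype=np.float32) + r for r in range(dev_num)]

    grad = np.sum(dout_per_rank, axis=0)  # all_reduce(dout): every rank holds the sum
    dx = np.split(grad, dev_num)[rank]    # split(grad)[rank]: this rank's slice
    if mean_flag:                         # the new branch from the diff
        dx = dx * (1.0 / dev_num)         # scale = 1/self.rank_size
    print(dx)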
