!7201 [AutoParallel]Add Check for allreduce fusion

Merge pull request !7201 from lichen/add_check_for_fusion
pull/7201/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0b75b3d9fe

@ -1895,7 +1895,11 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(pre_node);
auto pre_cnode = pre_node->cast<CNodePtr>();
if (pre_cnode == nullptr) {
if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
return loss_node_info;
}
if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
MS_LOG(DEBUG) << "pre_cnode:" << pre_cnode->ToString();
return loss_node_info;
}
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));

@ -294,6 +294,12 @@ class _AutoParallelContext:
else:
raise TypeError('indices must be a python list')
if len(set(indices)) != len(indices):
raise ValueError('indices has duplicate elements')
if sorted(indices) != indices:
raise ValueError('elements in indices must be sorted in ascending order')
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
@ -308,7 +314,7 @@ class _AutoParallelContext:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
if context.get_context("device_target") == "Ascend":
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
_set_fusion_strategy_by_idx(indices)
def get_all_reduce_fusion_split_indices(self, group=""):

Loading…
Cancel
Save