|
|
|
@ -294,6 +294,12 @@ class _AutoParallelContext:
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError('indices must be a python list')
|
|
|
|
|
|
|
|
|
|
if len(set(indices)) != len(indices):
|
|
|
|
|
raise ValueError('indices has duplicate elements')
|
|
|
|
|
|
|
|
|
|
if sorted(indices) != indices:
|
|
|
|
|
raise ValueError('elements in indices must be sorted in ascending order')
|
|
|
|
|
|
|
|
|
|
if isinstance(group, (str)):
|
|
|
|
|
group_len = len(group)
|
|
|
|
|
if group_len > _MAX_GROUP_NAME_LEN:
|
|
|
|
@ -308,7 +314,7 @@ class _AutoParallelContext:
|
|
|
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
|
|
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
|
|
|
|
|
if context.get_context("device_target") == "Ascend":
|
|
|
|
|
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
|
|
|
|
|
_set_fusion_strategy_by_idx(indices)
|
|
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self, group=""):
|
|
|
|
|